r""" Set non-SQL options for the connection which take effect
during execution.
- The method returns a copy of this :class:`_engine.Connection`
- which references
- the same underlying DBAPI connection, but also defines the given
- execution options which will take effect for a call to
- :meth:`execute`. As the new :class:`_engine.Connection`
- references the same
- underlying resource, it's usually a good idea to ensure that the copies
- will be discarded immediately, which is implicit if used as in::
+ For a "future" style connection, this method returns this same
+ :class:`_future.Connection` object with the new options added.
+
+ For a legacy connection, this method returns a copy of this
+ :class:`_engine.Connection` which references the same underlying DBAPI
+ connection, but also defines the given execution options which will
+ take effect for a call to
+ :meth:`execute`. As the new :class:`_engine.Connection` references the
+ same underlying resource, it's usually a good idea to ensure that
+ the copies will be discarded immediately, which is implicit if used
+ as in::
result = connection.execution_options(stream_results=True).\
execute(stmt)
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
- The underlying DBAPI connection is literally closed (if
- possible), and is discarded. Its source connection pool will
- typically lazily create a new connection to replace it.
+ An attempt will be made to close the underlying DBAPI connection
+ immediately; however if this operation fails, the error is logged
+ but not raised. The connection is then discarded whether or not
+ close() succeeded.
Upon the next use (where "use" typically means using the
:meth:`_engine.Connection.execute` method or similar),
will at the connection pool level invoke the
:meth:`_events.PoolEvents.invalidate` event.
+ :param exception: an optional ``Exception`` instance that's the
+ reason for the invalidation. is passed along to event handlers
+ and logging functions.
+
.. seealso::
:ref:`pool_connection_invalidation`
dispatch_cls._event_names.append(ls.name)
if getattr(cls, "_dispatch_target", None):
- cls._dispatch_target.dispatch = dispatcher(cls)
+ the_cls = cls._dispatch_target
+ if (
+ hasattr(the_cls, "__slots__")
+ and "_slots_dispatch" in the_cls.__slots__
+ ):
+ cls._dispatch_target.dispatch = slots_dispatcher(cls)
+ else:
+ cls._dispatch_target.dispatch = dispatcher(cls)
def _remove_dispatcher(cls):
def __get__(self, obj, cls):
if obj is None:
return self.dispatch
- obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj)
+
+ disp = self.dispatch._for_instance(obj)
+ try:
+ obj.__dict__["dispatch"] = disp
+ except AttributeError as ae:
+ util.raise_(
+ TypeError(
+ "target %r doesn't have __dict__, should it be "
+ "defining _slots_dispatch?" % (obj,)
+ ),
+ replace_context=ae,
+ )
+ return disp
+
+
+class slots_dispatcher(dispatcher):
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self.dispatch
+
+ if hasattr(obj, "_slots_dispatch"):
+ return obj._slots_dispatch
+
+ disp = self.dispatch._for_instance(obj)
+ obj._slots_dispatch = disp
return disp
from .engine import AsyncEngine # noqa
from .engine import AsyncTransaction # noqa
from .engine import create_async_engine # noqa
+from .events import AsyncConnectionEvents # noqa
+from .events import AsyncSessionEvents # noqa
from .result import AsyncMappingResult # noqa
from .result import AsyncResult # noqa
from .result import AsyncScalarResult # noqa
from .result import AsyncResult
from ... import exc
from ... import util
-from ...engine import Connection
from ...engine import create_engine as _create_engine
-from ...engine import Engine
from ...engine import Result
from ...engine import Transaction
-from ...engine.base import OptionEngineMixin
+from ...future import Connection
+from ...future import Engine
from ...sql import Executable
from ...util.concurrency import greenlet_spawn
return AsyncEngine(sync_engine)
-class AsyncConnection(StartableContext):
+class AsyncConnectable:
+ __slots__ = "_slots_dispatch"
+
+
+@util.create_proxy_methods(
+ Connection,
+ ":class:`_future.Connection`",
+ ":class:`_asyncio.AsyncConnection`",
+ classmethods=[],
+ methods=[],
+ attributes=[
+ "closed",
+ "invalidated",
+ "dialect",
+ "default_isolation_level",
+ ],
+)
+class AsyncConnection(StartableContext, AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
""" # noqa
+ # AsyncConnection is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncConnection that matches this one given only the
+ # "sync" elements.
__slots__ = (
"sync_engine",
"sync_connection",
)
def __init__(
- self, sync_engine: Engine, sync_connection: Optional[Connection] = None
+ self,
+ async_engine: "AsyncEngine",
+ sync_connection: Optional[Connection] = None,
):
- self.sync_engine = sync_engine
+ self.engine = async_engine
+ self.sync_engine = async_engine.sync_engine
self.sync_connection = sync_connection
async def start(self):
self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
return self
+ @property
+ def connection(self):
+ """Not implemented for async; call
+ :meth:`_asyncio.AsyncConnection.get_raw_connection`.
+
+ """
+ raise exc.InvalidRequestError(
+ "AsyncConnection.connection accessor is not implemented as the "
+ "attribute may need to reconnect on an invalidated connection. "
+ "Use the get_raw_connection() method."
+ )
+
+ async def get_raw_connection(self):
+ """Return the pooled DBAPI-level connection in use by this
+ :class:`_asyncio.AsyncConnection`.
+
+ This is typically the SQLAlchemy connection-pool proxied connection
+ which then has an attribute .connection that refers to the actual
+ DBAPI-level connection.
+ """
+ conn = self._sync_connection()
+
+ return await greenlet_spawn(getattr, conn, "connection")
+
+ @property
+ def _proxied(self):
+ return self.sync_connection
+
def _sync_connection(self):
if not self.sync_connection:
self._raise_for_not_started()
self._sync_connection()
return AsyncTransaction(self, nested=True)
+ async def invalidate(self, exception=None):
+ """Invalidate the underlying DBAPI connection associated with
+ this :class:`_engine.Connection`.
+
+ See the method :meth:`_engine.Connection.invalidate` for full
+ detail on this method.
+
+ """
+
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.invalidate, exception=exception)
+
+ async def get_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def set_isolation_level(self):
+ conn = self._sync_connection()
+ return await greenlet_spawn(conn.get_isolation_level)
+
+ async def execution_options(self, **opt):
+ r"""Set non-SQL options for the connection which take effect
+ during execution.
+
+ This returns this :class:`_asyncio.AsyncConnection` object with
+ the new options added.
+
+ See :meth:`_future.Connection.execution_options` for full details
+ on this method.
+
+ """
+
+ conn = self._sync_connection()
+ c2 = await greenlet_spawn(conn.execution_options, **opt)
+ assert c2 is conn
+ return self
+
async def commit(self):
"""Commit the transaction that is currently in progress.
await self.close()
-class AsyncEngine:
+@util.create_proxy_methods(
+ Engine,
+ ":class:`_future.Engine`",
+ ":class:`_asyncio.AsyncEngine`",
+ classmethods=[],
+ methods=[
+ "clear_compiled_cache",
+ "update_execution_options",
+ "get_execution_options",
+ ],
+ attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
+)
+class AsyncEngine(AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
""" # noqa
- __slots__ = ("sync_engine",)
+ # AsyncEngine is a thin proxy; no state should be added here
+ # that is not retrievable from the "sync" engine / connection, e.g.
+ # current transaction, info, etc. It should be possible to
+ # create a new AsyncEngine that matches this one given only the
+ # "sync" elements.
+ __slots__ = ("sync_engine", "_proxied")
_connection_cls = AsyncConnection
await self.conn.close()
def __init__(self, sync_engine: Engine):
- self.sync_engine = sync_engine
+ self.sync_engine = self._proxied = sync_engine
def begin(self):
"""Return a context manager which when entered will deliver an
"""
- return self._connection_cls(self.sync_engine)
+ return self._connection_cls(self)
async def raw_connection(self) -> Any:
"""Return a "raw" DBAPI connection from the connection pool.
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
+ def execution_options(self, **opt):
+ """Return a new :class:`_asyncio.AsyncEngine` that will provide
+ :class:`_asyncio.AsyncConnection` objects with the given execution
+ options.
+
+ Proxied from :meth:`_future.Engine.execution_options`. See that
+ method for details.
+
+ """
+
+ return AsyncEngine(self.sync_engine.execution_options(**opt))
-class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
- pass
+ async def dispose(self):
+ """Dispose of the connection pool used by this
+ :class:`_asyncio.AsyncEngine`.
+ This will close all connection pool connections that are
+ **currently checked in**. See the documentation for the underlying
+ :meth:`_future.Engine.dispose` method for further notes.
+
+ .. seealso::
+
+ :meth:`_future.Engine.dispose`
+
+ """
-AsyncEngine._option_cls = AsyncOptionEngine
+ return await greenlet_spawn(self.sync_engine.dispose)
class AsyncTransaction(StartableContext):
--- /dev/null
+from .engine import AsyncConnectable
+from .session import AsyncSession
+from ...engine import events as engine_event
+from ...orm import events as orm_event
+
+
+class AsyncConnectionEvents(engine_event.ConnectionEvents):
+ _target_class_doc = "SomeEngine"
+ _dispatch_target = AsyncConnectable
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes."
+ )
+
+
+class AsyncSessionEvents(orm_event.SessionEvents):
+ _target_class_doc = "SomeSession"
+ _dispatch_target = AsyncSession
+
+ @classmethod
+ def _listen(cls, event_key, retval=False):
+ raise NotImplementedError(
+ "asynchronous events are not implemented at this time. Apply "
+ "synchronous listeners to the AsyncSession.sync_session."
+ )
from typing import Any
from typing import Callable
-from typing import List
from typing import Mapping
from typing import Optional
from ...util.concurrency import greenlet_spawn
+@util.create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_asyncio.AsyncSession`",
+ classmethods=["object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "delete",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "get_bind",
+ "is_modified",
+ ],
+ attributes=[
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
class AsyncSession:
"""Asyncio version of :class:`_orm.Session`.
"""
+ __slots__ = (
+ "binds",
+ "bind",
+ "sync_session",
+ "_proxied",
+ "_slots_dispatch",
+ )
+
+ dispatch = None
+
def __init__(
self,
bind: AsyncEngine = None,
):
kw["future"] = True
if bind:
+ self.bind = engine
bind = engine._get_sync_engine(bind)
if binds:
+ self.binds = binds
binds = {
key: engine._get_sync_engine(b) for key, b in binds.items()
}
- self.sync_session = Session(bind=bind, binds=binds, **kw)
-
- def add(self, instance: object) -> None:
- """Place an object in this :class:`_asyncio.AsyncSession`.
-
- .. seealso::
-
- :meth:`_orm.Session.add`
-
- """
- self.sync_session.add(instance)
-
- def add_all(self, instances: List[object]) -> None:
- """Add the given collection of instances to this
- :class:`_asyncio.AsyncSession`."""
-
- self.sync_session.add_all(instances)
-
- def expire_all(self):
- """Expires all persistent instances within this Session.
-
- See :meth:`_orm.Session.expire_all` for usage details.
-
- """
- self.sync_session.expire_all()
-
- def expire(self, instance, attribute_names=None):
- """Expire the attributes on an instance.
-
- See :meth:`._orm.Session.expire` for usage details.
-
- """
- self.sync_session.expire()
+ self.sync_session = self._proxied = Session(
+ bind=bind, binds=binds, **kw
+ )
async def refresh(
self, instance, attribute_names=None, with_for_update=None
:class:`.Session` object's transactional state.
"""
+
+ # POSSIBLY TODO: here, we see that the sync engine / connection
+ # that are generated from AsyncEngine / AsyncConnection don't
+ # provide any backlink from those sync objects back out to the
+ # async ones. it's not *too* big a deal since AsyncEngine/Connection
+ # are just proxies and all the state is actually in the sync
+ # version of things. However! it has to stay that way :)
sync_connection = await greenlet_spawn(self.sync_session.connection)
- return engine.AsyncConnection(sync_connection.engine, sync_connection)
+ return engine.AsyncConnection(
+ engine.AsyncEngine(sync_connection.engine), sync_connection
+ )
def begin(self, **kw):
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
return AsyncSessionTransaction(self, nested=True)
async def rollback(self):
+ """Rollback the current transaction in progress."""
return await greenlet_spawn(self.sync_session.rollback)
async def commit(self):
+ """Commit the current transaction in progress."""
return await greenlet_spawn(self.sync_session.commit)
async def close(self):
+ """Close this :class:`_asyncio.AsyncSession`."""
return await greenlet_spawn(self.sync_session.close)
+ @classmethod
+ async def close_all(self):
+ """Close all :class:`_asyncio.AsyncSession` sessions."""
+ return await greenlet_spawn(self.sync_session.close_all)
+
async def __aenter__(self):
return self
elif isinstance(target, Session):
return target
else:
- return None
+ # allows alternate SessionEvents-like-classes to be consulted
+ return event.Events._accept_with(target)
@classmethod
def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
from . import exc as orm_exc
from .session import Session
from .. import exc as sa_exc
+from ..util import create_proxy_methods
from ..util import ScopedRegistry
from ..util import ThreadLocalRegistry
from ..util import warn
-
__all__ = ["scoped_session"]
+@create_proxy_methods(
+ Session,
+ ":class:`_orm.Session`",
+ ":class:`_orm.scoping.scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get_bind",
+ "is_modified",
+ "bulk_save_objects",
+ "bulk_insert_mappings",
+ "bulk_update_mappings",
+ "merge",
+ "query",
+ "refresh",
+ "rollback",
+ "scalar",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ "autocommit",
+ ],
+)
class scoped_session(object):
"""Provides scoped management of :class:`.Session` objects.
else:
self.registry = ThreadLocalRegistry(session_factory)
+ @property
+ def _proxied(self):
+ return self.registry()
+
def __call__(self, **kw):
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
-
-
-def instrument(name):
- def do(self, *args, **kwargs):
- return getattr(self.registry(), name)(*args, **kwargs)
-
- return do
-
-
-for meth in Session.public_methods:
- setattr(scoped_session, meth, instrument(meth))
-
-
-def makeprop(name):
- def set_(self, attr):
- setattr(self.registry(), name, attr)
-
- def get(self):
- return getattr(self.registry(), name)
-
- return property(get, set_)
-
-
-for prop in (
- "bind",
- "dirty",
- "deleted",
- "new",
- "identity_map",
- "is_active",
- "autoflush",
- "no_autoflush",
- "info",
- "autocommit",
-):
- setattr(scoped_session, prop, makeprop(prop))
-
-
-def clslevel(name):
- def do(cls, *args, **kwargs):
- return getattr(Session, name)(*args, **kwargs)
-
- return classmethod(do)
-
-
-for prop in ("close_all", "object_session", "identity_key"):
- setattr(scoped_session, prop, clslevel(prop))
"""
- public_methods = (
- "__contains__",
- "__iter__",
- "add",
- "add_all",
- "begin",
- "begin_nested",
- "close",
- "commit",
- "connection",
- "delete",
- "execute",
- "expire",
- "expire_all",
- "expunge",
- "expunge_all",
- "flush",
- "get_bind",
- "is_modified",
- "bulk_save_objects",
- "bulk_insert_mappings",
- "bulk_update_mappings",
- "merge",
- "query",
- "refresh",
- "rollback",
- "scalar",
- )
-
@util.deprecated_params(
autocommit=(
"2.0",
will unexpire attributes on access.
"""
- state = attributes.instance_state(obj)
+ try:
+ state = attributes.instance_state(obj)
+ except exc.NO_STATE as err:
+ util.raise_(
+ exc.UnmappedInstanceError(obj),
+ replace_context=err,
+ )
+
to_attach = self._before_attach(state, obj)
state._load_pending = True
if to_attach:
"Soft " if soft else "",
self.connection,
)
+
if soft:
self._soft_invalidate_time = time.time()
else:
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
- metadata = dict(target="target", fn="fn", name=fn.__name__)
+ metadata = dict(target="target", fn="__fn", name=fn.__name__)
metadata.update(format_argspec_plus(spec, grouped=False))
code = (
"""\
% metadata
)
decorated = _exec_code_in_env(
- code, {"target": target, "fn": fn}, fn.__name__
+ code, {"target": target, "__fn": fn}, fn.__name__
)
if not add_positional_parameters:
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
from .langhelpers import constructor_copy # noqa
from .langhelpers import constructor_key # noqa
from .langhelpers import counter # noqa
+from .langhelpers import create_proxy_methods # noqa
from .langhelpers import decode_slice # noqa
from .langhelpers import decorator # noqa
from .langhelpers import dictlike_iteritems # noqa
modules, classes, hierarchies, attributes, functions, and methods.
"""
+
from functools import update_wrapper
import hashlib
import inspect
passed positionally.
apply_kw
Like apply_pos, except keyword-ish args are passed as keywords.
+ apply_pos_proxied
+ Like apply_pos but omits the self/cls argument
Example::
spec = fn
args = compat.inspect_formatargspec(*spec)
+
+ apply_pos = compat.inspect_formatargspec(
+ spec[0], spec[1], spec[2], None, spec[4]
+ )
+
if spec[0]:
self_arg = spec[0][0]
+
+ apply_pos_proxied = compat.inspect_formatargspec(
+ spec[0][1:], spec[1], spec[2], None, spec[4]
+ )
+
elif spec[1]:
+ # im not sure what this is
self_arg = "%s[0]" % spec[1]
+
+ apply_pos_proxied = apply_pos
else:
self_arg = None
+ apply_pos_proxied = apply_pos
- apply_pos = compat.inspect_formatargspec(
- spec[0], spec[1], spec[2], None, spec[4]
- )
num_defaults = 0
if spec[3]:
num_defaults += len(spec[3])
self_arg=self_arg,
apply_pos=apply_pos,
apply_kw=apply_kw,
+ apply_pos_proxied=apply_pos_proxied,
)
else:
return dict(
self_arg=self_arg,
apply_pos=apply_pos[1:-1],
apply_kw=apply_kw[1:-1],
+ apply_pos_proxied=apply_pos_proxied[1:-1],
)
"""
if method is object.__init__:
- args = grouped and "(self)" or "self"
+ args = "(self)" if grouped else "self"
+ proxied = "()" if grouped else ""
else:
try:
return format_argspec_plus(method, grouped=grouped)
except TypeError:
args = (
- grouped
- and "(self, *args, **kwargs)"
- or "self, *args, **kwargs"
+ "(self, *args, **kwargs)"
+ if grouped
+ else "self, *args, **kwargs"
)
- return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
+ proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
+ return dict(
+ self_arg="self",
+ args=args,
+ apply_pos=args,
+ apply_kw=args,
+ apply_pos_proxied=proxied,
+ )
+
+
+def create_proxy_methods(
+ target_cls,
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ classmethods=(),
+ methods=(),
+ attributes=(),
+):
+ """A class decorator that will copy attributes to a proxy class.
+
+ The class to be instrumented must define a single accessor "_proxied".
+
+ """
+
+ def decorate(cls):
+ def instrument(name, clslevel=False):
+ fn = getattr(target_cls, name)
+ spec = compat.inspect_getfullargspec(fn)
+ env = {}
+
+ spec = _update_argspec_defaults_into_env(spec, env)
+ caller_argspec = format_argspec_plus(spec, grouped=False)
+
+ metadata = {
+ "name": fn.__name__,
+ "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
+ "args": caller_argspec["args"],
+ "self_arg": caller_argspec["self_arg"],
+ }
+
+ if clslevel:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return target_cls.%(name)s(%(apply_pos_proxied)s)"
+ % metadata
+ )
+ env["target_cls"] = target_cls
+ else:
+ code = (
+ "def %(name)s(%(args)s):\n"
+ " return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)" # noqa E501
+ % metadata
+ )
+
+ proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+ proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+ proxy_fn.__doc__ = inject_docstring_text(
+ fn.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (target_cls_sphinx_name, proxy_cls_sphinx_name),
+ 1,
+ )
+
+ if clslevel:
+ proxy_fn = classmethod(proxy_fn)
+
+ return proxy_fn
+
+ def makeprop(name):
+ attr = target_cls.__dict__.get(name, None)
+
+ if attr is not None:
+ doc = inject_docstring_text(
+ attr.__doc__,
+ ".. container:: class_bases\n\n "
+ "Proxied for the %s class on behalf of the %s class."
+ % (
+ target_cls_sphinx_name,
+ proxy_cls_sphinx_name,
+ ),
+ 1,
+ )
+ else:
+ doc = None
+
+ code = (
+ "def set_(self, attr):\n"
+ " self._proxied.%(name)s = attr\n"
+ "def get(self):\n"
+ " return self._proxied.%(name)s\n"
+ "get.__doc__ = doc\n"
+ "getset = property(get, set_)"
+ ) % {"name": name}
+
+ getset = _exec_code_in_env(code, {"doc": doc}, "getset")
+
+ return getset
+
+ for meth in methods:
+ if hasattr(cls, meth):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, meth)
+ )
+ setattr(cls, meth, instrument(meth))
+
+ for prop in attributes:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, makeprop(prop))
+
+ for prop in classmethods:
+ if hasattr(cls, prop):
+ raise TypeError(
+ "class %s already has a method %s" % (cls, prop)
+ )
+ setattr(cls, prop, instrument(prop, clslevel=True))
+
+ return cls
+
+ return decorate
def getargspec_init(method):
from sqlalchemy.testing.util import gc_collect
-class EventsTest(fixtures.TestBase):
+class TearDownLocalEventsFixture(object):
+ def tearDown(self):
+ classes = set()
+ for entry in event.base._registrars.values():
+ for evt_cls in entry:
+ if evt_cls.__module__ == __name__:
+ classes.add(evt_cls)
+
+ for evt_cls in classes:
+ event.base._remove_dispatcher(evt_cls)
+
+
+class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
def setUp(self):
self.Target = Target
- def tearDown(self):
- event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
def test_register_class(self):
def listen(x, y):
pass
)
-class NamedCallTest(fixtures.TestBase):
+class SlotsEventsTest(fixtures.TestBase):
+ @testing.requires.python3
+ def test_no_slots_dispatch(self):
+ class Target(object):
+ __slots__ = ()
+
+ class TargetEvents(event.Events):
+ _dispatch_target = Target
+
+ def event_one(self, x, y):
+ pass
+
+ def event_two(self, x):
+ pass
+
+ def event_three(self, x):
+ pass
+
+ t1 = Target()
+
+ with testing.expect_raises_message(
+ TypeError,
+ r"target .*Target.* doesn't have __dict__, should it "
+ "be defining _slots_dispatch",
+ ):
+ event.listen(t1, "event_one", Mock())
+
+ def test_slots_dispatch(self):
+ class Target(object):
+ __slots__ = ("_slots_dispatch",)
+
+ class TargetEvents(event.Events):
+ _dispatch_target = Target
+
+ def event_one(self, x, y):
+ pass
+
+ def event_two(self, x):
+ pass
+
+ def event_three(self, x):
+ pass
+
+ t1 = Target()
+
+ m1 = Mock()
+ event.listen(t1, "event_one", m1)
+
+ t1.dispatch.event_one(2, 4)
+
+ eq_(m1.mock_calls, [call(2, 4)])
+
+
+class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
def _fixture(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})])
-class LegacySignatureTest(fixtures.TestBase):
+class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
def setUp(self):
self.TargetOne = TargetOne
- def tearDown(self):
- event.base._remove_dispatcher(
- self.TargetOne.__dict__["dispatch"].events
- )
-
def test_legacy_accept(self):
canary = Mock()
)
-class ClsLevelListenTest(fixtures.TestBase):
- def tearDown(self):
- event.base._remove_dispatcher(
- self.TargetOne.__dict__["dispatch"].events
- )
-
+class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
assert handler2 not in s2.dispatch.event_one
-class AcceptTargetsTest(fixtures.TestBase):
+class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
def setUp(self):
self.TargetOne = TargetOne
self.TargetTwo = TargetTwo
- def tearDown(self):
- event.base._remove_dispatcher(
- self.TargetOne.__dict__["dispatch"].events
- )
- event.base._remove_dispatcher(
- self.TargetTwo.__dict__["dispatch"].events
- )
-
def test_target_accept(self):
"""Test that events of the same name are routed to the correct
collection based on the type of target given.
eq_(list(t2.dispatch.event_one), [listen_two, listen_four])
-class CustomTargetsTest(fixtures.TestBase):
+class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
def setUp(self):
self.Target = Target
- def tearDown(self):
- event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
def test_indirect(self):
def listen(x, y):
pass
)
-class SubclassGrowthTest(fixtures.TestBase):
+class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
def setUp(self):
eq_(self.Target.__subclasses__(), [])
-class ListenOverrideTest(fixtures.TestBase):
+class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom listen functions which change the listener function
signature."""
self.Target = Target
- def tearDown(self):
- event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
-
def test_listen_override(self):
listen_one = Mock()
listen_two = Mock()
eq_(listen_one.mock_calls, [call(12)])
-class PropagateTest(fixtures.TestBase):
+class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, arg):
eq_(listen_two.mock_calls, [])
-class JoinTest(fixtures.TestBase):
+class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
self.TargetFactory = TargetFactory
self.TargetElement = TargetElement
- def tearDown(self):
- for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget):
- if "dispatch" in cls.__dict__:
- event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
-
def test_neither(self):
element = self.TargetFactory().create()
element.run_event(1)
)
-class DisableClsPropagateTest(fixtures.TestBase):
+class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
self.BaseTarget = BaseTarget
self.SubTarget = SubTarget
- def tearDown(self):
- for cls in (self.SubTarget, self.BaseTarget):
- if "dispatch" in cls.__dict__:
- event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
-
def test_listen_invoke_clslevel(self):
canary = Mock()
eq_(canary.mock_calls, [])
-class RemovalTest(fixtures.TestBase):
+class RemovalTest(TearDownLocalEventsFixture, fixtures.TestBase):
def _fixture(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
class _Py3KFixtures(object):
- pass
+ def _kw_only_fixture(self):
+ pass
+
+ def _kw_plus_posn_fixture(self):
+ pass
+
+ def _kw_opt_fixture(self):
+ pass
if util.py3k:
for k in _locals:
setattr(_Py3KFixtures, k, _locals[k])
+py3k_fixtures = _Py3KFixtures()
-class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
- def _test_format_argspec_plus(self, fn, wanted, grouped=None):
-
- # test direct function
- if grouped is None:
- parsed = util.format_argspec_plus(fn)
- else:
- parsed = util.format_argspec_plus(fn, grouped=grouped)
- eq_(parsed, wanted)
-
- # test sending fullargspec
- spec = compat.inspect_getfullargspec(fn)
- if grouped is None:
- parsed = util.format_argspec_plus(spec)
- else:
- parsed = util.format_argspec_plus(spec, grouped=grouped)
- eq_(parsed, wanted)
- def test_specs(self):
- self._test_format_argspec_plus(
+class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
+ @testing.combinations(
+ (
lambda: None,
{
"args": "()",
"self_arg": None,
"apply_kw": "()",
"apply_pos": "()",
+ "apply_pos_proxied": "()",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda: None,
- {"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""},
- grouped=False,
- )
-
- self._test_format_argspec_plus(
+ {
+ "args": "",
+ "self_arg": None,
+ "apply_kw": "",
+ "apply_pos": "",
+ "apply_pos_proxied": "",
+ },
+ False,
+ ),
+ (
lambda self: None,
{
"args": "(self)",
"self_arg": "self",
"apply_kw": "(self)",
"apply_pos": "(self)",
+ "apply_pos_proxied": "()",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda self: None,
{
"args": "self",
"self_arg": "self",
"apply_kw": "self",
"apply_pos": "self",
+ "apply_pos_proxied": "",
},
- grouped=False,
- )
-
- self._test_format_argspec_plus(
+ False,
+ ),
+ (
lambda *a: None,
{
"args": "(*a)",
"self_arg": "a[0]",
"apply_kw": "(*a)",
"apply_pos": "(*a)",
+ "apply_pos_proxied": "(*a)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda **kw: None,
{
"args": "(**kw)",
"self_arg": None,
"apply_kw": "(**kw)",
"apply_pos": "(**kw)",
+ "apply_pos_proxied": "(**kw)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda *a, **kw: None,
{
"args": "(*a, **kw)",
"self_arg": "a[0]",
"apply_kw": "(*a, **kw)",
"apply_pos": "(*a, **kw)",
+ "apply_pos_proxied": "(*a, **kw)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a, *b: None,
{
"args": "(a, *b)",
"self_arg": "a",
"apply_kw": "(a, *b)",
"apply_pos": "(a, *b)",
+ "apply_pos_proxied": "(*b)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a, **b: None,
{
"args": "(a, **b)",
"self_arg": "a",
"apply_kw": "(a, **b)",
"apply_pos": "(a, **b)",
+ "apply_pos_proxied": "(**b)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a, *b, **c: None,
{
"args": "(a, *b, **c)",
"self_arg": "a",
"apply_kw": "(a, *b, **c)",
"apply_pos": "(a, *b, **c)",
+ "apply_pos_proxied": "(*b, **c)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a, b=1, **c: None,
{
"args": "(a, b=1, **c)",
"self_arg": "a",
"apply_kw": "(a, b=b, **c)",
"apply_pos": "(a, b, **c)",
+ "apply_pos_proxied": "(b, **c)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a=1, b=2: None,
{
"args": "(a=1, b=2)",
"self_arg": "a",
"apply_kw": "(a=a, b=b)",
"apply_pos": "(a, b)",
+ "apply_pos_proxied": "(b)",
},
- )
-
- self._test_format_argspec_plus(
+ True,
+ ),
+ (
lambda a=1, b=2: None,
{
"args": "a=1, b=2",
"self_arg": "a",
"apply_kw": "a=a, b=b",
"apply_pos": "a, b",
+ "apply_pos_proxied": "b",
},
- grouped=False,
- )
-
- if util.py3k:
- self._test_format_argspec_plus(
- self._kw_only_fixture,
- {
- "args": "self, a, *, b, c",
- "self_arg": "self",
- "apply_pos": "self, a, *, b, c",
- "apply_kw": "self, a, b=b, c=c",
- },
- grouped=False,
- )
- self._test_format_argspec_plus(
- self._kw_plus_posn_fixture,
- {
- "args": "self, a, *args, b, c",
- "self_arg": "self",
- "apply_pos": "self, a, *args, b, c",
- "apply_kw": "self, a, b=b, c=c, *args",
- },
- grouped=False,
- )
- self._test_format_argspec_plus(
- self._kw_opt_fixture,
- {
- "args": "self, a, *, b, c='c'",
- "self_arg": "self",
- "apply_pos": "self, a, *, b, c",
- "apply_kw": "self, a, b=b, c=c",
- },
- grouped=False,
- )
+ False,
+ ),
+ (
+ py3k_fixtures._kw_only_fixture,
+ {
+ "args": "self, a, *, b, c",
+ "self_arg": "self",
+ "apply_pos": "self, a, *, b, c",
+ "apply_kw": "self, a, b=b, c=c",
+ "apply_pos_proxied": "a, *, b, c",
+ },
+ False,
+ testing.requires.python3,
+ ),
+ (
+ py3k_fixtures._kw_plus_posn_fixture,
+ {
+ "args": "self, a, *args, b, c",
+ "self_arg": "self",
+ "apply_pos": "self, a, *args, b, c",
+ "apply_kw": "self, a, b=b, c=c, *args",
+ "apply_pos_proxied": "a, *args, b, c",
+ },
+ False,
+ testing.requires.python3,
+ ),
+ (
+ py3k_fixtures._kw_opt_fixture,
+ {
+ "args": "self, a, *, b, c='c'",
+ "self_arg": "self",
+ "apply_pos": "self, a, *, b, c",
+ "apply_kw": "self, a, b=b, c=c",
+ "apply_pos_proxied": "a, *, b, c",
+ },
+ False,
+ testing.requires.python3,
+ ),
+ argnames="fn,wanted,grouped",
+ )
+ def test_specs(self, fn, wanted, grouped):
+
+ # test direct function
+ if grouped is None:
+ parsed = util.format_argspec_plus(fn)
+ else:
+ parsed = util.format_argspec_plus(fn, grouped=grouped)
+ eq_(parsed, wanted)
+
+ # test sending fullargspec
+ spec = compat.inspect_getfullargspec(fn)
+ if grouped is None:
+ parsed = util.format_argspec_plus(spec)
+ else:
+ parsed = util.format_argspec_plus(spec, grouped=grouped)
+ eq_(parsed, wanted)
@testing.requires.cpython
def test_init_grouped(self):
"self_arg": "self",
"apply_pos": "(self)",
"apply_kw": "(self)",
+ "apply_pos_proxied": "()",
}
wrapper_spec = {
"args": "(self, *args, **kwargs)",
"self_arg": "self",
"apply_pos": "(self, *args, **kwargs)",
"apply_kw": "(self, *args, **kwargs)",
+ "apply_pos_proxied": "(*args, **kwargs)",
}
custom_spec = {
"args": "(slef, a=123)",
"self_arg": "slef", # yes, slef
"apply_pos": "(slef, a)",
+ "apply_pos_proxied": "(a)",
"apply_kw": "(slef, a=a)",
}
"self_arg": "self",
"apply_pos": "self",
"apply_kw": "self",
+ "apply_pos_proxied": "",
}
wrapper_spec = {
"args": "self, *args, **kwargs",
"self_arg": "self",
"apply_pos": "self, *args, **kwargs",
"apply_kw": "self, *args, **kwargs",
+ "apply_pos_proxied": "*args, **kwargs",
}
custom_spec = {
"args": "slef, a=123",
"self_arg": "slef", # yes, slef
"apply_pos": "slef, a",
"apply_kw": "slef, a=a",
+ "apply_pos_proxied": "a",
}
self._test_init(False, object_spec, wrapper_spec, custom_spec)
from sqlalchemy import Column
from sqlalchemy import delete
+from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
+from sqlalchemy import text
from sqlalchemy import union_all
from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as asyncio_exc
from sqlalchemy.testing import async_test
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not
+from sqlalchemy.testing import mock
from sqlalchemy.testing.asyncio import assert_raises_message_async
+from sqlalchemy.util.concurrency import greenlet_spawn
class EngineFixture(fixtures.TablesTest):
class AsyncEngineTest(EngineFixture):
__backend__ = True
+ def test_proxied_attrs_engine(self, async_engine):
+ sync_engine = async_engine.sync_engine
+
+ is_(async_engine.url, sync_engine.url)
+ is_(async_engine.pool, sync_engine.pool)
+ is_(async_engine.dialect, sync_engine.dialect)
+ eq_(async_engine.name, sync_engine.name)
+ eq_(async_engine.driver, sync_engine.driver)
+ eq_(async_engine.echo, sync_engine.echo)
+
+ def test_clear_compiled_cache(self, async_engine):
+ async_engine.sync_engine._compiled_cache["foo"] = "bar"
+ eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
+ async_engine.clear_compiled_cache()
+ assert "foo" not in async_engine.sync_engine._compiled_cache
+
+ def test_execution_options(self, async_engine):
+ a2 = async_engine.execution_options(foo="bar")
+ assert isinstance(a2, _async_engine.AsyncEngine)
+ eq_(a2.sync_engine._execution_options, {"foo": "bar"})
+ eq_(async_engine.sync_engine._execution_options, {})
+
+ """
+
+ attr uri, pool, dialect, engine, name, driver, echo
+ methods clear_compiled_cache, update_execution_options,
+ execution_options, get_execution_options, dispose
+
+ """
+
+ @async_test
+ async def test_proxied_attrs_connection(self, async_engine):
+ conn = await async_engine.connect()
+
+ sync_conn = conn.sync_connection
+
+ is_(conn.engine, async_engine)
+ is_(conn.closed, sync_conn.closed)
+ is_(conn.dialect, async_engine.sync_engine.dialect)
+ eq_(conn.default_isolation_level, sync_conn.default_isolation_level)
+
+ @async_test
+ async def test_invalidate(self, async_engine):
+ conn = await async_engine.connect()
+
+ is_(conn.invalidated, False)
+
+ connection_fairy = await conn.get_raw_connection()
+ is_(connection_fairy.is_valid, True)
+ dbapi_connection = connection_fairy.connection
+
+ await conn.invalidate()
+ assert dbapi_connection._connection.is_closed()
+
+ new_fairy = await conn.get_raw_connection()
+ is_not(new_fairy.connection, dbapi_connection)
+ is_not(new_fairy, connection_fairy)
+ is_(new_fairy.is_valid, True)
+ is_(connection_fairy.is_valid, False)
+
+ @async_test
+ async def test_get_dbapi_connection_raise(self, async_engine):
+
+ conn = await async_engine.connect()
+
+ with testing.expect_raises_message(
+ exc.InvalidRequestError,
+ "AsyncConnection.connection accessor is not "
+ "implemented as the attribute",
+ ):
+ conn.connection
+
+ @async_test
+ async def test_get_raw_connection(self, async_engine):
+
+ conn = await async_engine.connect()
+
+ pooled = await conn.get_raw_connection()
+ is_(pooled, conn.sync_connection.connection)
+
+ @async_test
+ async def test_isolation_level(self, async_engine):
+ conn = await async_engine.connect()
+
+ sync_isolation_level = await greenlet_spawn(
+ conn.sync_connection.get_isolation_level
+ )
+ isolation_level = await conn.get_isolation_level()
+
+ eq_(isolation_level, sync_isolation_level)
+
+ await conn.execution_options(isolation_level="SERIALIZABLE")
+ isolation_level = await conn.get_isolation_level()
+
+ eq_(isolation_level, "SERIALIZABLE")
+
+ @async_test
+ async def test_dispose(self, async_engine):
+ c1 = await async_engine.connect()
+ c2 = await async_engine.connect()
+
+ await c1.close()
+ await c2.close()
+
+ p1 = async_engine.pool
+ eq_(async_engine.pool.checkedin(), 2)
+
+ await async_engine.dispose()
+ eq_(async_engine.pool.checkedin(), 0)
+ is_not(p1, async_engine.pool)
+
@async_test
async def test_init_once_concurrency(self, async_engine):
c1 = async_engine.connect()
)
+class AsyncEventTest(EngineFixture):
+ """The engine events all run in their normal synchronous context.
+
+ we do not provide an asyncio event interface at this time.
+
+ """
+
+ __backend__ = True
+
+ @async_test
+ async def test_no_async_listeners(self, async_engine):
+ with testing.expect_raises_message(
+ NotImplementedError,
+ "asynchronous events are not implemented "
+ "at this time. Apply synchronous listeners to the "
+ "AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes.",
+ ):
+ event.listen(async_engine, "before_cursor_execute", mock.Mock())
+
+ conn = await async_engine.connect()
+
+ with testing.expect_raises_message(
+ NotImplementedError,
+ "asynchronous events are not implemented "
+ "at this time. Apply synchronous listeners to the "
+ "AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes.",
+ ):
+ event.listen(conn, "before_cursor_execute", mock.Mock())
+
+ @async_test
+ async def test_sync_before_cursor_execute_engine(self, async_engine):
+ canary = mock.Mock()
+
+ event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
+
+ async with async_engine.connect() as conn:
+ sync_conn = conn.sync_connection
+ await conn.execute(text("select 1"))
+
+ eq_(
+ canary.mock_calls,
+ [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+ )
+
+ @async_test
+ async def test_sync_before_cursor_execute_connection(self, async_engine):
+ canary = mock.Mock()
+
+ async with async_engine.connect() as conn:
+ sync_conn = conn.sync_connection
+
+ event.listen(
+ async_engine.sync_engine, "before_cursor_execute", canary
+ )
+ await conn.execute(text("select 1"))
+
+ eq_(
+ canary.mock_calls,
+ [mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
+ )
+
+
class AsyncResultTest(EngineFixture):
@testing.combinations(
(None,), ("scalars",), ("mappings",), argnames="filter_"
+from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.testing import async_test
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import mock
from ...orm import _fixtures
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
+ @async_test
+ async def test_delete(self, async_session):
+ User = self.classes.User
+
+ async with async_session.begin():
+ u1 = User(name="u1")
+
+ async_session.add(u1)
+
+ await async_session.flush()
+
+ conn = await async_session.connection()
+
+ eq_(await conn.scalar(select(func.count(User.id))), 1)
+
+ async_session.delete(u1)
+
+ await async_session.flush()
+
+ eq_(await conn.scalar(select(func.count(User.id))), 0)
+
@async_test
async def test_flush(self, async_session):
User = self.classes.User
is_(new_u_merged, u1)
eq_(u1.name, "new u1")
+
+
+class AsyncEventTest(AsyncFixture):
+ """The engine events all run in their normal synchronous context.
+
+ we do not provide an asyncio event interface at this time.
+
+ """
+
+ __backend__ = True
+
+ @async_test
+ async def test_no_async_listeners(self, async_session):
+ with testing.expect_raises(
+ NotImplementedError,
+ "NotImplementedError: asynchronous events are not implemented "
+ "at this time. Apply synchronous listeners to the "
+ "AsyncEngine.sync_engine or "
+ "AsyncConnection.sync_connection attributes.",
+ ):
+ event.listen(async_session, "before_flush", mock.Mock())
+
+ @async_test
+ async def test_sync_before_commit(self, async_session):
+ canary = mock.Mock()
+
+ event.listen(async_session.sync_session, "before_commit", canary)
+
+ async with async_session.begin():
+ pass
+
+ eq_(
+ canary.mock_calls,
+ [mock.call(async_session.sync_session)],
+ )
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import mock
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
mock_scope_func.return_value = 1
s2 = Session(autocommit=True)
assert s2.autocommit == True
+
+ def test_methods_etc(self):
+ mock_session = Mock()
+ mock_session.bind = "the bind"
+
+ sess = scoped_session(lambda: mock_session)
+
+ sess.add("add")
+ sess.delete("delete")
+
+ eq_(sess.bind, "the bind")
+
+ eq_(
+ mock_session.mock_calls,
+ [mock.call.add("add", True), mock.call.delete("delete")],
+ )
+
+ with mock.patch(
+ "sqlalchemy.orm.session.object_session"
+ ) as mock_object_session:
+ sess.object_session("foo")
+
+ eq_(mock_object_session.mock_calls, [mock.call("foo")])
+import inspect as _py_inspect
+
import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import ForeignKey
class SessionInterface(fixtures.TestBase):
"""Bogus args to Session methods produce actionable exceptions."""
- # TODO: expand with message body assertions.
-
_class_methods = set(("connection", "execute", "get_bind", "scalar"))
def _public_session_methods(self):
Session = sa.orm.session.Session
- blacklist = set(("begin", "query"))
-
+ blacklist = {"begin", "query", "bind_mapper", "get", "bind_table"}
+ specials = {"__iter__", "__contains__"}
ok = set()
- for meth in Session.public_methods:
- if meth in blacklist:
- continue
- spec = inspect_getfullargspec(getattr(Session, meth))
- if len(spec[0]) > 1 or spec[1]:
- ok.add(meth)
+ for name in dir(Session):
+ if (
+ name in Session.__dict__
+ and (not name.startswith("_") or name in specials)
+ and (
+ _py_inspect.ismethod(getattr(Session, name))
+ or _py_inspect.isfunction(getattr(Session, name))
+ )
+ ):
+ if name in blacklist:
+ continue
+ spec = inspect_getfullargspec(getattr(Session, name))
+ if len(spec[0]) > 1 or spec[1]:
+ ok.add(name)
return ok
def _map_it(self, cls):
def raises_(method, *args, **kw):
x_raises_(create_session(), method, *args, **kw)
- raises_("__contains__", user_arg)
-
- raises_("add", user_arg)
+ for name in [
+ "__contains__",
+ "is_modified",
+ "merge",
+ "refresh",
+ "add",
+ "delete",
+ "expire",
+ "expunge",
+ "enable_relationship_loading",
+ ]:
+ raises_(name, user_arg)
raises_("add_all", (user_arg,))
- raises_("delete", user_arg)
-
- raises_("expire", user_arg)
-
- raises_("expunge", user_arg)
-
# flush will no-op without something in the unit of work
def _():
class OK(object):
_()
- raises_("is_modified", user_arg)
-
- raises_("merge", user_arg)
-
- raises_("refresh", user_arg)
-
instance_methods = (
self._public_session_methods()
- self._class_methods