From eb8a6ed51a7d23146c87823d4aeb186d33990fe5 Mon Sep 17 00:00:00 2001 From: Jason Kirtland Date: Wed, 21 May 2008 23:58:16 +0000 Subject: [PATCH] - unrolled loops for the simplified Session.get_bind() args - restored the chunk of test r4806 deleted (!) --- lib/sqlalchemy/orm/session.py | 36 +++++++++++++++++------------------ test/orm/session.py | 13 +++++++++++-- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 36d71a763f..5d6326ac40 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -685,9 +685,6 @@ class Session(object): clause Optional, any ``ClauseElement`` - instance - Optional, an instance of a mapped class - """ return self.__connection(self.get_bind(mapper, clause, _state)) @@ -834,7 +831,7 @@ class Session(object): _state Optional, SA internal representation of a mapped instance - + """ if mapper is clause is _state is None: if self.bind: @@ -845,19 +842,21 @@ class Session(object): "Connection, and no context was provided to locate " "a binding.") - mappers = [] - if _state is not None: - mappers.append(_state_mapper(_state)) - if mapper is not None: - mappers.append(_class_to_mapper(mapper)) + s_mapper = _state is not None and _state_mapper(_state) or None + c_mapper = mapper is not None and _class_to_mapper(mapper) or None # manually bound? if self.__binds: - for m in mappers: - if m.base_mapper in self.__binds: - return self.__binds[m.base_mapper] - elif m.mapped_table in self.__binds: - return self.__binds[m.mapped_table] + if s_mapper: + if s_mapper.base_mapper in self.__binds: + return self.__binds[s_mapper.base_mapper] + elif s_mapper.mapped_table in self.__binds: + return self.__binds[s_mapper.mapped_table] + if c_mapper: + if c_mapper.base_mapper in self.__binds: + return self.__binds[c_mapper.base_mapper] + elif c_mapper.mapped_table in self.__binds: + return self.__binds[c_mapper.mapped_table] if clause: for t in sql_util.find_tables(clause): if t in self.__binds: @@ -868,13 +867,14 @@ class Session(object): if isinstance(clause, sql.expression.ClauseElement) and clause.bind: return clause.bind - for m in mappers: - if m.mapped_table.bind: - return m.mapped_table.bind + if s_mapper and s_mapper.mapped_table.bind: + return s_mapper.mapped_table.bind + if c_mapper and c_mapper.mapped_table.bind: + return c_mapper.mapped_table.bind context = [] if mapper is not None: - context.append('mapper %s' % _class_to_mapper(mapper)) + context.append('mapper %s' % c_mapper) if clause is not None: context.append('SQL expression') if _state is not None: diff --git a/test/orm/session.py b/test/orm/session.py index 5e45afb45f..7f41812278 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -986,12 +986,13 @@ class SessionInterface(testing.TestBase): # TODO: expand with message body assertions. - _class_methods = set(('get', 'load')) + _class_methods = set(( + 'connection', 'execute', 'get', 'get_bind', 'load', 'scalar')) def _public_session_methods(self): Session = sa.orm.session.Session - blacklist = set(('begin', 'query', 'connection', 'execute', 'get_bind', 'scalar')) + blacklist = set(('begin', 'query')) ok = set() for meth in Session.public_methods: @@ -1067,10 +1068,18 @@ class SessionInterface(testing.TestBase): self.assertRaises(sa.orm.exc.UnmappedClassError, callable_, *args, **kw) + raises_('connection', mapper=user_arg) + + raises_('execute', 'SELECT 1', mapper=user_arg) + raises_('get', user_arg, 1) + raises_('get_bind', mapper=user_arg) + raises_('load', user_arg, 1) + raises_('scalar', 'SELECT 1', mapper=user_arg) + eq_(watchdog, self._class_methods, watchdog.symmetric_difference(self._class_methods)) -- 2.47.3