]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- unrolled loops for the simplified Session.get_bind() args
authorJason Kirtland <jek@discorporate.us>
Wed, 21 May 2008 23:58:16 +0000 (23:58 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 21 May 2008 23:58:16 +0000 (23:58 +0000)
- restored the chunk of test r4806 deleted (!)

lib/sqlalchemy/orm/session.py
test/orm/session.py

index 36d71a763f33aa850419b593ce8d2dd2a4bd2768..5d6326ac407fec62b9743db3de7bf99aaf22a834 100644 (file)
@@ -685,9 +685,6 @@ class Session(object):
         clause
           Optional, any ``ClauseElement``
 
-        instance
-          Optional, an instance of a mapped class
-
         """
         return self.__connection(self.get_bind(mapper, clause, _state))
 
@@ -834,7 +831,7 @@ class Session(object):
 
         _state
           Optional, SA internal representation of a mapped instance
-            
+
         """
         if mapper is clause is _state is None:
             if self.bind:
@@ -845,19 +842,21 @@ class Session(object):
                     "Connection, and no context was provided to locate "
                     "a binding.")
 
-        mappers = []
-        if _state is not None:
-            mappers.append(_state_mapper(_state))
-        if mapper is not None:
-            mappers.append(_class_to_mapper(mapper))
+        s_mapper = _state is not None and _state_mapper(_state) or None
+        c_mapper = mapper is not None and _class_to_mapper(mapper) or None
 
         # manually bound?
         if self.__binds:
-            for m in mappers:
-                if m.base_mapper in self.__binds:
-                    return self.__binds[m.base_mapper]
-                elif m.mapped_table in self.__binds:
-                    return self.__binds[m.mapped_table]
+            if s_mapper:
+                if s_mapper.base_mapper in self.__binds:
+                    return self.__binds[s_mapper.base_mapper]
+                elif s_mapper.mapped_table in self.__binds:
+                    return self.__binds[s_mapper.mapped_table]
+            if c_mapper:
+                if c_mapper.base_mapper in self.__binds:
+                    return self.__binds[c_mapper.base_mapper]
+                elif c_mapper.mapped_table in self.__binds:
+                    return self.__binds[c_mapper.mapped_table]
             if clause:
                 for t in sql_util.find_tables(clause):
                     if t in self.__binds:
@@ -868,13 +867,14 @@ class Session(object):
         if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
             return clause.bind
 
-        for m in mappers:
-            if m.mapped_table.bind:
-                return m.mapped_table.bind
+        if s_mapper and s_mapper.mapped_table.bind:
+            return s_mapper.mapped_table.bind
+        if c_mapper and c_mapper.mapped_table.bind:
+            return c_mapper.mapped_table.bind
 
         context = []
         if mapper is not None:
-            context.append('mapper %s' % _class_to_mapper(mapper))
+            context.append('mapper %s' % c_mapper)
         if clause is not None:
             context.append('SQL expression')
         if _state is not None:
index 5e45afb45f635e37f99cdefaaecbf5b42119f6aa..7f418122786781ff5b2b85ceab1cab04eb0e7c6c 100644 (file)
@@ -986,12 +986,13 @@ class SessionInterface(testing.TestBase):
 
     # TODO: expand with message body assertions.
 
-    _class_methods = set(('get', 'load'))
+    _class_methods = set((
+        'connection', 'execute', 'get', 'get_bind', 'load', 'scalar'))
 
     def _public_session_methods(self):
         Session = sa.orm.session.Session
 
-        blacklist = set(('begin', 'query', 'connection', 'execute', 'get_bind', 'scalar'))
+        blacklist = set(('begin', 'query'))
 
         ok = set()
         for meth in Session.public_methods:
@@ -1067,10 +1068,18 @@ class SessionInterface(testing.TestBase):
             self.assertRaises(sa.orm.exc.UnmappedClassError,
                               callable_, *args, **kw)
 
+        raises_('connection', mapper=user_arg)
+
+        raises_('execute', 'SELECT 1', mapper=user_arg)
+
         raises_('get', user_arg, 1)
 
+        raises_('get_bind', mapper=user_arg)
+
         raises_('load', user_arg, 1)
 
+        raises_('scalar', 'SELECT 1', mapper=user_arg)
+
         eq_(watchdog, self._class_methods,
             watchdog.symmetric_difference(self._class_methods))