]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- factored some fixes from trunk to lazyloader use_get, logging
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 May 2007 19:35:50 +0000 (19:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 20 May 2007 19:35:50 +0000 (19:35 +0000)
- deferred inheritance loading: polymorphic mappers can be constructed *without*
a select_table argument.  inheriting mappers whose tables were not
represented in the initial load will issue a second SQL query immediately,
once per instance (i.e. not very efficient for large lists),
in order to load the remaining columns.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/orm/inheritance/polymorph.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 3dab91e901f223692df3301001607a59cd8b4e8f..0730a5c27ce0977722c8b35e10764305d28a7815 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,13 @@
+0.4.0
+
+- orm
+    - deferred inheritance loading: polymorphic mappers can be constructed *without* 
+      a select_table argument.  inheriting mappers whose tables were not 
+      represented in the initial load will issue a second SQL query immediately,
+      once per instance (i.e. not very efficient for large lists), 
+      in order to load the remaining columns.
+    
+0.3.XXX
 - engines
     - added detach() to Connection, allows underlying DBAPI connection to be detached 
       from its pool, closing on dereference/close() instead of being reused by the pool.
index 7e44d8a42eaf0234a0f867c9688b7cb8d47ba62d..91e99994433ec8b0ef56d0425e90d25cad47d3be 100644 (file)
@@ -1403,9 +1403,11 @@ class Mapper(object):
             if discriminator is not None:
                 mapper = self.polymorphic_map[discriminator]
                 if mapper is not self:
+                    if ('needsload', mapper) not in context.attributes:
+                        context.attributes[('needsload', mapper)] = (self, [t for t in mapper.tables if t not in self.tables])
                     row = self.translate_row(mapper, row)
                     return mapper._instance(context, row, result=result, skip_polymorphic=True)
-
+                    
         # look in main identity map.  if its there, we dont do anything to it,
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
@@ -1484,6 +1486,43 @@ class Mapper(object):
 
         return obj
 
+    def _post_instance(self, context, instance):
+        (hosted_mapper, needs_tables) = context.attributes.get(('needsload', self), (None, None))
+        if needs_tables is None or len(needs_tables) == 0:
+            return
+        
+        self.__log_debug("Post query loading instance " + mapperutil.instance_str(instance))
+        if ('post_select', self) not in context.attributes:
+            cond = self.inherit_condition.copy_container()
+
+            param_names = []
+            def visit_binary(binary):
+                leftcol = binary.left
+                rightcol = binary.right
+                if leftcol is None or rightcol is None:
+                    return
+                if leftcol.table not in needs_tables:
+                    binary.left = sql.bindparam(leftcol.name, None, type=binary.right.type, unique=True)
+                    param_names.append(leftcol)
+                elif rightcol not in needs_tables:
+                    binary.right = sql.bindparam(rightcol.name, None, type=binary.right.type, unique=True)
+                    param_names.append(rightcol)
+            mapperutil.BinaryVisitor(visit_binary).traverse(cond)
+            statement = sql.select(needs_tables, cond)
+            context.attributes[('post_select', self)] = (statement, param_names)
+            
+        (statement, binds) = context.attributes.get(('post_select', self))
+        
+        identitykey = self.instance_key(instance)
+        
+        params = {}
+        for c in binds:
+            params[c.name] = self.get_attr_by_column(instance, c)
+        row = context.session.connection(self).execute(statement, **params).fetchone()
+        for prop in self.__props.values():
+            if prop.parent is not hosted_mapper:
+                prop.execute(context, instance, row, identitykey, True)
+
     def translate_row(self, tomapper, row):
         """Translate the column keys of a row into a new or proxied
         row that can be understood by another mapper.
@@ -1494,8 +1533,8 @@ class Mapper(object):
 
         newrow = util.DictDecorator(row)
         for c in tomapper.mapped_table.c:
-            c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=True)
-            if row.has_key(c2):
+            c2 = self.mapped_table.corresponding_column(c, keys_ok=True, raiseerr=False)
+            if c2 and row.has_key(c2):
                 newrow[c] = row[c2]
         return newrow
 
@@ -1504,6 +1543,7 @@ class Mapper(object):
 
         This method iterates through the list of MapperProperty objects attached to this Mapper
         and calls each properties execute() method."""
+        
         for prop in self.__props.values():
             prop.execute(selectcontext, instance, row, identitykey, isnew)
 
index 38279f5f210b03272a7ade4901036cfeeb503ae4..a87fa8d19a227faaef1ac3db53a145dfb1e37ecd 100644 (file)
@@ -882,16 +882,20 @@ class Query(object):
                 if isinstance(m, type):
                     m = mapper.class_mapper(m)
                 if isinstance(m, mapper.Mapper):
-                    appender = []
-                    def proc(context, row):
-                        if not m._instance(context, row, appender):
-                            appender.append(None)
-                    process.append((proc, appender))
+                    def x(m):
+                        appender = []
+                        def proc(context, row):
+                            if not m._instance(context, row, appender):
+                                appender.append(None)
+                        process.append((proc, appender))
+                    x(m)
                 elif isinstance(m, sql.ColumnElement) or isinstance(m, basestring):
-                    res = []
-                    def proc(context, row):
-                        res.append(row[m])
-                    process.append((proc, res))
+                    def y(m):
+                        res = []
+                        def proc(context, row):
+                            res.append(row[m])
+                        process.append((proc, res))
+                    y(m)
             result = []
         else:
             result = util.UniqueAppender([])
@@ -901,6 +905,9 @@ class Query(object):
             for proc in process:
                 proc[0](context, row)
 
+        for value in context.identity_map.values():
+            object_mapper(value)._post_instance(context, value)
+        
         # store new stuff in the identity map
         for value in context.identity_map.values():
             session._register_persistent(value)
index ddf7d6251c8bb1af496178a65f1c1d6b4fae6858..9f0f68f5dd4be033aa3d896fb5cd06718a653f26 100644 (file)
@@ -593,12 +593,12 @@ class Session(object):
     def _attach(self, obj):
         """Attach the given object to this ``Session``."""
 
-        if getattr(obj, '_sa_session_id', None) != self.hash_key:
-            old = getattr(obj, '_sa_session_id', None)
-            if old is not None and _sessions.has_key(old):
+        old_id = getattr(obj, '_sa_session_id', None)
+        if old_id != self.hash_key:
+            if old_id is not None and _sessions.has_key(old_id):
                 raise exceptions.InvalidRequestError("Object '%s' is already attached "
                                                      "to session '%s' (this is '%s')" %
-                                                     (repr(obj), old, id(self)))
+                                                     (repr(obj), old_id, id(self)))
 
                 # auto-removal from the old session is disabled.  but if we decide to
                 # turn it back on, do it as below: gingerly since _sessions is a WeakValueDict
index f1b159318f374fa5567842fed5e22c82a5e1c3b4..1eeb77735b35d941aba7eee3069daa5ce36cf846 100644 (file)
@@ -36,7 +36,10 @@ class ColumnLoader(LoaderStrategy):
         if isnew:
             if self._should_log_debug:
                 self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
-            instance.__dict__[self.key] = row[self.columns[0]]
+            try:
+                instance.__dict__[self.key] = row[self.columns[0]]
+            except KeyError:
+                pass
         
 ColumnLoader.logger = logging.class_logger(ColumnLoader)
 
@@ -162,10 +165,15 @@ class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
         (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self)
+        
+        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
 
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         self.use_get = not self.uselist and query.Query(self.mapper)._get_clause.compare(self.lazywhere)
+        if self.use_get:
+            self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
+
 
     def init_class_attribute(self):
         self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
@@ -303,8 +311,6 @@ class LazyLoader(AbstractRelationLoader):
                 li.traverse(secondaryjoin)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
  
-        if hasattr(cls, 'parent_property'):
-            LazyLoader.logger.info(str(cls.parent_property) + " lazy loading clause " + str(lazywhere))
         return (lazywhere, binds, reverse)
     _create_lazy_clause = classmethod(_create_lazy_clause)
     
index 35cd30b30415d807d170c1c5358724240c3fe2d7..9a93474dc7b6d63f3f89e3447c44f14978b1ef6f 100644 (file)
@@ -1964,7 +1964,9 @@ class ClauseList(ClauseElement):
         including a comparison of all the clause items.
         """
 
-        if isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses):
+        if not isinstance(other, ClauseList) and len(self.clauses) == 1:
+            return self.clauses[0].compare(other)
+        elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses):
             for i in range(0, len(self.clauses)):
                 if not self.clauses[i].compare(other.clauses[i]):
                     return False
index ea5a468d2afdd3a0101a615622af5e48742cd707..af1bbe27036c7da25bb5933a0255ec7c53415933 100644 (file)
@@ -426,7 +426,10 @@ class UniqueAppender(object):
         if item not in self.set:
             self.set.add(item)
             self._data_appender(item)
-
+    
+    def __iter__(self):
+        return iter(self.data)
+        
 class ScopedRegistry(object):
     """A Registry that can store one or multiple instances of a single
     class on a per-thread scoped basis, or on a customized scope.
index 9d886cf3f1f86dc3e315550a715c5406276cf5d6..107203e472f665974cbe391b8236db7dc42454d8 100644 (file)
@@ -192,7 +192,7 @@ class RelationToSubclassTest(PolymorphTest):
         assert sets.Set([e.get_name() for e in c.managers]) == sets.Set(['pointy haired boss'])
         assert c.managers[0].company is c
         
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False):
+def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, use_union=False):
     """generates a round trip test.
     
     include_base - whether or not to include the base 'person' type in the union.
@@ -203,7 +203,9 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
     class RoundTripTest(PolymorphTest):
         def test_roundtrip(self):
             # create a union that represents both types of joins.  
-            if include_base:
+            if not use_union:
+                person_join = None
+            elif include_base:
                 person_join = polymorphic_union(
                     {
                         'engineer':people.join(engineers),
@@ -218,9 +220,9 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
                     }, None, 'pjoin')
 
             if redefine_colprop:
-                person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+                person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
             else:
-                person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=person_join.c.type, polymorphic_identity='person')
+                person_mapper = mapper(Person, people, select_table=person_join, polymorphic_on=people.c.type, polymorphic_identity='person')
             
             mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
             mapper(Manager, managers, inherits=person_mapper, polymorphic_identity='manager')
@@ -260,9 +262,9 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
             for e in c.employees:
                 print e, e._instance_key, e.company
             if include_base:
-                assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'joesmith', 'wally', 'jsmith'])
+                assert sets.Set([(e.get_name(), getattr(e, 'status', None)) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('joesmith', None), ('wally', 'CGG'), ('jsmith', 'ABA')])
             else:
-                assert sets.Set([e.get_name() for e in c.employees]) == sets.Set(['pointy haired boss', 'dilbert', 'wally', 'jsmith'])
+                assert sets.Set([(e.get_name(), e.status) for e in c.employees]) == sets.Set([('pointy haired boss', 'AAB'), ('dilbert', 'BBA'), ('wally', 'CGG'), ('jsmith', 'ABA')])
             print "\n"
 
         
@@ -300,7 +302,7 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         (lazy_relation and "Lazy" or "Eager"),
         (include_base and "Inclbase" or ""),
         (redefine_colprop and "Redefcol" or ""),
-        (use_literal_join and "Litjoin" or "")
+        (not use_union and "Nounion" or (use_literal_join and "Litjoin" or ""))
     )
     return RoundTripTest
 
@@ -308,8 +310,9 @@ for include_base in [True, False]:
     for lazy_relation in [True, False]:
         for redefine_colprop in [True, False]:
             for use_literal_join in [True, False]:
-                testclass = generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join)
-                exec("%s = testclass" % testclass.__name__)
+                for use_union in [True, False]:
+                    testclass = generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, use_union)
+                    exec("%s = testclass" % testclass.__name__)
                 
 if __name__ == "__main__":    
     testbase.main()
index 4ce1be2fa762aeb56d15784137a979787d86f8ef..889a7c925972c896ae04a8224cf2e87aac9a111c 100644 (file)
@@ -1087,6 +1087,22 @@ class LazyTest(MapperSuperTest):
         self.echo(repr(l[0].user))
         self.assert_(l[0].user is not None)
 
+    def testuseget(self):
+        """test that a simple many-to-one lazyload optimizes to use query.get().
+        
+        this is done currently by comparing the 'get' SQL clause of the query
+        to the 'lazy' SQL clause of the lazy loader, so it relies heavily on 
+        ClauseElement.compare()"""
+        
+        m = mapper(Address, addresses, properties = dict(
+            user = relation(mapper(User, users), lazy = True)
+        ))
+        sess = create_session()
+        a1 = sess.query(Address).get_by(email_address = "ed@wood.com")
+        u1 = sess.query(User).get(8)
+        def go():
+            assert a1.user is u1
+        self.assert_sql_count(db, go, 0)
 
     def testdouble(self):
         """tests lazy loading with two relations simulatneously, from the same table, using aliases.  """
@@ -1619,6 +1635,20 @@ class InstancesTest(MapperSuperTest):
             (user8, 3),
             (user9, 0)
         ]
+        
+    def testmappersplustwocolumns(self):
+        mapper(User, users)
+        s = select([users, func.count(addresses.c.address_id).label('count'), ("Name:" + users.c.user_name).label('concat')], from_obj=[users.outerjoin(addresses)], group_by=[c for c in users.c], order_by=[users.c.user_id])
+        sess = create_session()
+        (user7, user8, user9) = sess.query(User).select()
+        q = sess.query(User)
+        l = q.instances(s.execute(), "count", "concat")
+        print l
+        assert l == [
+            (user7, 1, "Name:jack"),
+            (user8, 3, "Name:ed"),
+            (user9, 0, "Name:fred")
+        ]
 
 
 if __name__ == "__main__":