]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 00:45:36 +0000 (00:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 00:45:36 +0000 (00:45 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/sql.py
test/mapper.py

index 36f131745373fd2f84fd51cb077c3090eea7a2d2..99bd7d455ef6f48767ed7dddaecbcf0c06b7ef76 100644 (file)
@@ -59,18 +59,19 @@ class Mapper(object):
     def instances(self, cursor):
         result = []
         cursor = ResultProxy(cursor)
-        lastinstance = None
+        localmap = IdentityMap()
         while True:
             row = cursor.fetchone()
             if row is None:
                 break
-            instance = self._instance(row)
-            if instance != lastinstance:
+                
+            identitykey = localmap.get_key(row, self.class_, self.table)
+            if not localmap.map.has_key(identitykey):
+                instance = self._create(row, identitykey, localmap)
                 result.append(instance)
-                lastinstance = instance
             else:
                 for key, prop in self.props.iteritems():
-                    prop.execute(instance, key, row, True)
+                    prop.execute(instance, key, row, identitykey, localmap, True)
                 
         return result
         
@@ -110,13 +111,18 @@ class Mapper(object):
         statement.use_labels = True
         return self.instances(statement.execute(**params))
 
-    def _instance(self, row):
-        return self.identitymap.get(row, self.class_, self.table, creator = self._create)
+    def _identity_key(self, row):
+        return self.identitymap.get_key(row, self.class_, self.table)
 
-    def _create(self, row):
+    def _create(self, row, identitykey, localmap):
         instance = self.class_()
+        for column in self.table.primary_keys:
+            if row[column.label] is None:
+                return None
         for key, prop in self.props.iteritems():
-            prop.execute(instance, key, row, False)
+            prop.execute(instance, key, row, identitykey, localmap, False)
+        self.identitymap.map[identitykey] = instance
+        localmap.map[identitykey] = instance
         return instance
 
 
@@ -130,7 +136,7 @@ class ColumnProperty(MapperProperty):
     def __init__(self, column):
         self.column = column
         
-    def execute(self, instance, key, row, isduplicate):
+    def execute(self, instance, key, row, identitykey, localmap, isduplicate):
         if not isduplicate:
             setattr(instance, key, row[self.column.label])
 
@@ -138,6 +144,7 @@ class EagerLoader(MapperProperty):
     def __init__(self, mapper, whereclause):
         self.mapper = mapper
         self.whereclause = whereclause
+        
     def setup(self, key, primarytable, statement):
         targettable = self.mapper.table
         if hasattr(statement, '_outerjoin'):
@@ -146,14 +153,17 @@ class EagerLoader(MapperProperty):
             statement._outerjoin = sql.outerjoin(primarytable, targettable, self.whereclause)
         statement.append_from(statement._outerjoin)
         statement.append_column(targettable)
-    def execute(self, instance, key, row, isduplicate):
+        
+    def execute(self, instance, key, row, identitykey, localmap, isduplicate):
         try:
             list = getattr(instance, key)
         except AttributeError:
             list = []
             setattr(instance, key, list)
-        subinstance = self.mapper._instance(row)
-        if subinstance is not None:
+        
+        identitykey = self.mapper._identity_key(row)
+        if not localmap.has_key(identitykey):
+            subinstance = self.mapper._create(row, identitykey, localmap)
             list.append(subinstance)
 
 class ResultProxy:
@@ -186,22 +196,23 @@ class IdentityMap(object):
     def __init__(self):
         self.map = {}
         self.keystereotypes = {}
+    
+    def has_key(self, key):
+        return self.map.has_key(key)
+        
+    def get_key(self, row, class_, table):
+        return (class_, table.id, tuple([row[column.label] for column in table.primary_keys]))
         
-    def get(self, row, class_, table, creator = None):
+    def get(self, row, class_, table, key = None):
         """given a database row, a class to be instantiated, and a table corresponding 
         to the row, returns a corrseponding object instance, if any, from the identity
         map.  the primary keys specified in the table will be used to indicate which
         columns from the row form the effective key of the instance."""
-        key = (class_, table, tuple([row[column.label] for column in table.primary_keys]))
         
-        try:
-            return self.map[key]
-        except KeyError:
-            newinstance = creator(row)
-            for column in table.primary_keys:
-                if row[column.label] is None:
-                    return None
-            return self.map.setdefault(key, newinstance)
+        if key is None:
+            key = self.get_key(row, class_, table)
+
+        return self.map[key]
             
     
     
index 071a6185ce5708bd549e0bf43f0f86059e8048f9..7438408d69fef28c20e874d25821ec1cf83dedf4 100644 (file)
@@ -296,7 +296,6 @@ class Join(Selectable):
         
     def _get_from_objects(self):
         result = [self] + [FromClause(from_key = c.id) for c in self.left._get_from_objects() + self.right._get_from_objects()]
-        print repr([c.id for c in result])
         return result
         
 class Alias(Selectable):
index 23d01baac4fbf509669bcf73c34065417d29997c..3d0ee0b8acdc31a1c095898d03ecd2352f36223c 100644 (file)
@@ -12,9 +12,17 @@ import sqlalchemy.mapper as mapper
 
 class User:
     def __repr__(self):
-        return ("User: " + repr(self.user_id) + " " + self.user_name + repr(getattr(self, 'addresses', None)) +
-            repr(getattr(self, 'orders', None))
-            )
+        return (
+"""
+User ID: %s
+Addresses: %s
+Orders: %s
+Open Orders %s
+Closed Orders %s
+------------------
+""" % tuple([self.user_id] + [repr(getattr(self, attr, None)) for attr in ('addresses', 'orders', 'orders_open', 'orders_closed')])
+)
+
             
 class Address:
     def __repr__(self):
@@ -67,7 +75,6 @@ class MapperTest(PersistTest):
         m = mapper.Mapper(User, self.users)
         l = m.select()
         print repr(l)
-        print repr(m.identitymap.map)
 
     def testeager(self):
         m = mapper.Mapper(User, self.users, properties = dict(
@@ -77,12 +84,19 @@ class MapperTest(PersistTest):
         print repr(l)
 
     def testmultieager(self):
+        m = mapper.Mapper(User, self.users, properties = dict(
+            addresses = mapper.EagerLoader(mapper.Mapper(Address, self.addresses), self.users.c.user_id==self.addresses.c.user_id),
+            orders = mapper.EagerLoader(mapper.Mapper(Order, self.orders), and_(self.orders.c.isopen == 1, self.users.c.user_id==self.orders.c.user_id)),
+        ), identitymap = mapper.IdentityMap())
+        l = m.select()
+        print repr(l)
+#        return
         openorders = alias(self.orders, 'openorders')
         closedorders = alias(self.orders, 'closedorders')
         m = mapper.Mapper(User, self.users, properties = dict(
             orders_open = mapper.EagerLoader(mapper.Mapper(Order, openorders), and_(openorders.c.isopen == 1, self.users.c.user_id==openorders.c.user_id)),
             orders_closed = mapper.EagerLoader(mapper.Mapper(Order, closedorders), and_(closedorders.c.isopen == 0, self.users.c.user_id==closedorders.c.user_id))
-        ))
+        ), identitymap = mapper.IdentityMap())
         l = m.select()
         print repr(l)