]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 01:02:47 +0000 (01:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jul 2005 01:02:47 +0000 (01:02 +0000)
lib/sqlalchemy/mapper.py
test/mapper.py

index 99bd7d455ef6f48767ed7dddaecbcf0c06b7ef76..2a178b56b6b10de3acf125d8e06882d187150a3f 100644 (file)
@@ -38,13 +38,15 @@ import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
 
 class Mapper(object):
-    def __init__(self, class_, table, properties = None, identitymap = None):
+    def __init__(self, class_, selectable, properties = None, identitymap = None):
         self.class_ = class_
-        self.table = table
+
+        self.selectable = selectable
+        self.table = self._find_table(selectable)
         
         self.props = {}
         
-        for column in table.columns:
+        for column in self.selectable.columns:
             self.props[column.key] = ColumnProperty(column)
 
         if properties is not None:
@@ -65,7 +67,7 @@ class Mapper(object):
             if row is None:
                 break
                 
-            identitykey = localmap.get_key(row, self.class_, self.table)
+            identitykey = localmap.get_key(row, self.class_, self.table, self.selectable)
             if not localmap.map.has_key(identitykey):
                 instance = self._create(row, identitykey, localmap)
                 result.append(instance)
@@ -100,11 +102,19 @@ class Mapper(object):
     def delete(self, whereclause = None, **params):
         pass
 
-
+    class TableFinder(sql.ClauseVisitor):
+        def visit_table(self, table):
+            self.table = table
+            
+    def _find_table(self, selectable):
+        tf = Mapper.TableFinder()
+        selectable.accept_visitor(tf)
+        return tf.table
+        
     def _select_whereclause(self, whereclause = None, **params):
-        statement = sql.select([self.table], whereclause)
+        statement = sql.select([self.selectable], whereclause)
         for key, value in self.props.iteritems():
-            value.setup(key, self.table, statement) 
+            value.setup(key, self.selectable, statement) 
         return self._select_statement(statement, **params)
     
     def _select_statement(self, statement, **params):
@@ -112,11 +122,11 @@ class Mapper(object):
         return self.instances(statement.execute(**params))
 
     def _identity_key(self, row):
-        return self.identitymap.get_key(row, self.class_, self.table)
+        return self.identitymap.get_key(row, self.class_, self.table, self.selectable)
 
     def _create(self, row, identitykey, localmap):
         instance = self.class_()
-        for column in self.table.primary_keys:
+        for column in self.selectable.primary_keys:
             if row[column.label] is None:
                 return None
         for key, prop in self.props.iteritems():
@@ -146,7 +156,7 @@ class EagerLoader(MapperProperty):
         self.whereclause = whereclause
         
     def setup(self, key, primarytable, statement):
-        targettable = self.mapper.table
+        targettable = self.mapper.selectable
         if hasattr(statement, '_outerjoin'):
             statement._outerjoin = sql.outerjoin(statement._outerjoin, targettable, self.whereclause)
         else:
@@ -200,18 +210,10 @@ class IdentityMap(object):
     def has_key(self, key):
         return self.map.has_key(key)
         
-    def get_key(self, row, class_, table):
-        return (class_, table.id, tuple([row[column.label] for column in table.primary_keys]))
-        
-    def get(self, row, class_, table, key = None):
-        """given a database row, a class to be instantiated, and a table corresponding 
-        to the row, returns a corrseponding object instance, if any, from the identity
-        map.  the primary keys specified in the table will be used to indicate which
-        columns from the row form the effective key of the instance."""
+    def get_key(self, row, class_, table, selectable):
+        return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))
         
-        if key is None:
-            key = self.get_key(row, class_, table)
-
+    def get(self, key):
         return self.map[key]
             
     
index 3d0ee0b8acdc31a1c095898d03ecd2352f36223c..a2fc249c75cdd93a455fb11c447ec356addab834 100644 (file)
@@ -99,6 +99,7 @@ class MapperTest(PersistTest):
         ), identitymap = mapper.IdentityMap())
         l = m.select()
         print repr(l)
+        print repr(m.identitymap.map)
 
         
     def tearDown(self):