]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 04:27:55 +0000 (04:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 3 Sep 2005 04:27:55 +0000 (04:27 +0000)
lib/sqlalchemy/mapper.py

index 548ca612b4ea3fdf57387a5cb8bfe7a4a6aef7ac..17e48f7e34f9103b9295efecf930aa73f103f62b 100644 (file)
@@ -94,18 +94,47 @@ class Mapper(object):
         else:
             self.identitymap = _global_identitymap
 
+        # object attribute names mapped to MapperProperty objects
         self.props = {}
+        
+        # table columns mapped to lists of MapperProperty objects
+        # using a list allows a single column to be defined as 
+        # populating multiple object attributes
+        self.columntoproperty = {}
+        
+        # the original properties argument to match against similar 
+        # arguments, for caching purposes
+        self.properties = properties
+
+        # load custom properties 
+        if self.properties is not None:
+            for key, prop in self.properties.iteritems():
+                self.props[key] = prop
+                if isinstance(prop, ColumnProperty):
+                    for col in prop.columns:
+                        proplist = self.columntoproperty.setdefault(col, [])
+                        proplist.append(prop)
+
+        # load properties from the main Selectable object,
+        # not overriding those set up in the 'properties' argument
         for column in self.selectable.columns:
+            if self.columntoproperty.has_key(column):
+                continue
+                
             prop = self.props.get(column.key, None)
             if prop is None:
                 prop = ColumnProperty(column)
                 self.props[column.key] = prop
-            else:
+            elif isinstance(prop, ColumnProperty):
                 prop.columns.append(column)
-        self.properties = properties
-        if properties is not None:
-            for key, value in properties.iteritems():
-                self.props[key] = value
+            else:
+                continue
+        
+            # its a ColumnProperty - match the columns
+            # back to the property
+            proplist = self.columntoproperty.setdefault(column, [])
+            proplist.append(prop)
+
 
         if isroot:
             self.init(self)
@@ -202,6 +231,12 @@ class Mapper(object):
         else:
             return self._select_whereclause(arg, **params)
 
+    def _getattrbycolumn(self, obj, column):
+        return self.columntoproperty[column][0].getattr(obj)
+
+    def _setattrbycolumn(self, obj, column, value):
+        self.columntoproperty[column][0].setattr(obj, value)
+        
     def save(self, obj, traverse = True, refetch = False):
         """saves the object across all its primary tables.  
         based on the existence of the primary key for each table, either inserts or updates.
@@ -215,27 +250,21 @@ class Mapper(object):
 
         if getattr(obj, 'dirty', True):
             def foo():
-                props = {}
-                for prop in self.props.values():
-                    if not isinstance(prop, ColumnProperty):
-                        continue
-                    for col in prop.columns:
-                        props[col] = prop
                 for table in self.tables:
                     params = {}
                     for primary_key in table.primary_keys:
-                        if props[primary_key].getattr(obj) is None:
+                        if self._getattrbycolumn(obj, primary_key) is None:
                             statement = table.insert()
                             for col in table.columns:
-                                params[col.key] = props[col].getattr(obj)
+                                params[col.key] = self._getattrbycolumn(obj, col)
                             break
                     else:
                         clause = sql.and_()
                         for col in table.columns:
                             if col.primary_key:
-                                clause.clauses.append(col == props[col].getattr(obj))
+                                clause.clauses.append(col == self._getattrbycolumn(obj, col))
                             else:
-                                params[col.key] = props[col].getattr(obj)
+                                params[col.key] = self._getattrbycolumn(obj, col)
                         statement = table.update(clause)
                     statement.echo = self.echo
                     statement.execute(**params)
@@ -245,10 +274,11 @@ class Mapper(object):
                         for col in table.primary_keys:
                             newid = primary_keys[index]
                             index += 1
-                            props[col].setattr(obj, newid)
+                            self._setattrbycolumn(obj, col, newid)
                         self.put(obj)
-                # unset dirty flag
-                obj.dirty = False
+                # unset dirty flag, if the object defines one
+                if hasattr(obj, 'dirty'):
+                    obj.dirty = False
                 for prop in self.props.values():
                     if not isinstance(prop, ColumnProperty):
                         prop.save(obj, traverse, refetch)
@@ -379,6 +409,8 @@ class MapperProperty:
 class ColumnProperty(MapperProperty):
     """describes an object attribute that corresponds to a table column."""
     def __init__(self, *columns):
+        """the list of columns describes a single object property populating 
+        multiple columns, typcially across multiple tables"""
         self.columns = list(columns)
 
     def getattr(self, object):
@@ -400,9 +432,9 @@ class ColumnProperty(MapperProperty):
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         if not isduplicate:
             if self.use_smart:
-                instance.__dict__[self.key] = row[self.column.label]
+                instance.__dict__[self.key] = row[self.columns[0].label]
             else:
-                setattr(instance, self.key, row[self.column.label])
+                setattr(instance, self.key, row[self.columns[0].label])