From a868c698c1ef9ae8abf124f35836fb611b283706 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 3 Sep 2005 04:27:55 +0000 Subject: [PATCH] --- lib/sqlalchemy/mapper.py | 72 +++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 548ca612b4..17e48f7e34 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -94,18 +94,47 @@ class Mapper(object): else: self.identitymap = _global_identitymap + # object attribute names mapped to MapperProperty objects self.props = {} + + # table columns mapped to lists of MapperProperty objects + # using a list allows a single column to be defined as + # populating multiple object attributes + self.columntoproperty = {} + + # the original properties argument to match against similar + # arguments, for caching purposes + self.properties = properties + + # load custom properties + if self.properties is not None: + for key, prop in self.properties.iteritems(): + self.props[key] = prop + if isinstance(prop, ColumnProperty): + for col in prop.columns: + proplist = self.columntoproperty.setdefault(col, []) + proplist.append(prop) + + # load properties from the main Selectable object, + # not overriding those set up in the 'properties' argument for column in self.selectable.columns: + if self.columntoproperty.has_key(column): + continue + prop = self.props.get(column.key, None) if prop is None: prop = ColumnProperty(column) self.props[column.key] = prop - else: + elif isinstance(prop, ColumnProperty): prop.columns.append(column) - self.properties = properties - if properties is not None: - for key, value in properties.iteritems(): - self.props[key] = value + else: + continue + + # its a ColumnProperty - match the columns + # back to the property + proplist = self.columntoproperty.setdefault(column, []) + proplist.append(prop) + if isroot: self.init(self) @@ -202,6 +231,12 @@ class Mapper(object): else: return self._select_whereclause(arg, **params) + def _getattrbycolumn(self, obj, column): + return self.columntoproperty[column][0].getattr(obj) + + def _setattrbycolumn(self, obj, column, value): + self.columntoproperty[column][0].setattr(obj, value) + def save(self, obj, traverse = True, refetch = False): """saves the object across all its primary tables. based on the existence of the primary key for each table, either inserts or updates. @@ -215,27 +250,21 @@ class Mapper(object): if getattr(obj, 'dirty', True): def foo(): - props = {} - for prop in self.props.values(): - if not isinstance(prop, ColumnProperty): - continue - for col in prop.columns: - props[col] = prop for table in self.tables: params = {} for primary_key in table.primary_keys: - if props[primary_key].getattr(obj) is None: + if self._getattrbycolumn(obj, primary_key) is None: statement = table.insert() for col in table.columns: - params[col.key] = props[col].getattr(obj) + params[col.key] = self._getattrbycolumn(obj, col) break else: clause = sql.and_() for col in table.columns: if col.primary_key: - clause.clauses.append(col == props[col].getattr(obj)) + clause.clauses.append(col == self._getattrbycolumn(obj, col)) else: - params[col.key] = props[col].getattr(obj) + params[col.key] = self._getattrbycolumn(obj, col) statement = table.update(clause) statement.echo = self.echo statement.execute(**params) @@ -245,10 +274,11 @@ class Mapper(object): for col in table.primary_keys: newid = primary_keys[index] index += 1 - props[col].setattr(obj, newid) + self._setattrbycolumn(obj, col, newid) self.put(obj) - # unset dirty flag - obj.dirty = False + # unset dirty flag, if the object defines one + if hasattr(obj, 'dirty'): + obj.dirty = False for prop in self.props.values(): if not isinstance(prop, ColumnProperty): prop.save(obj, traverse, refetch) @@ -379,6 +409,8 @@ class MapperProperty: class ColumnProperty(MapperProperty): """describes an object attribute that corresponds to a table column.""" def __init__(self, *columns): + """the list of columns describes a single object property populating + multiple columns, typcially across multiple tables""" self.columns = list(columns) def getattr(self, object): @@ -400,9 +432,9 @@ class ColumnProperty(MapperProperty): def execute(self, instance, row, identitykey, localmap, isduplicate): if not isduplicate: if self.use_smart: - instance.__dict__[self.key] = row[self.column.label] + instance.__dict__[self.key] = row[self.columns[0].label] else: - setattr(instance, self.key, row[self.column.label]) + setattr(instance, self.key, row[self.columns[0].label]) -- 2.47.2