From 2cb2d6bc1be320d4f4b395276f8ce0b9d56c39b6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 29 Aug 2005 00:07:30 +0000 Subject: [PATCH] --- lib/sqlalchemy/mapper.py | 125 ++++++++++++++++++--------------------- test/mapper.py | 30 ++++------ 2 files changed, 69 insertions(+), 86 deletions(-) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index bef7474f24..62a1bb8567 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -96,7 +96,12 @@ class Mapper(object): self.props = {} for column in self.selectable.columns: - self.props[column.key] = ColumnProperty(column) + prop = self.props.get(column.key, None) + if prop is None: + prop = ColumnProperty(column) + self.props[column.key] = prop + else: + prop.columns.append(column) self.properties = properties if properties is not None: for key, value in properties.iteritems(): @@ -209,18 +214,45 @@ class Mapper(object): """ if getattr(obj, 'dirty', True): - f = def(): + def foo(): + props = {} + for prop in self.props.values(): + if not isinstance(prop, ColumnProperty): + continue + for col in prop.columns: + props[col] = prop for table in self.tables: - for col in table.columns: - if getattr(obj, col.key, None) is None: - self.insert(obj, table) + params = {} + for primary_key in table.primary_keys: + if props[primary_key].getattr(obj) is None: + statement = table.insert() + for col in table.columns: + params[col.key] = props[col].getattr(obj) break else: - self.update(obj, table) - + clause = sql.and_() + for col in table.columns: + if col.primary_key: + clause.clauses.append(col == props[col].getattr(obj)) + else: + params[col.key] = props[col].getattr(obj) + statement = table.update(clause) + statement.echo = self.echo + statement.execute(**params) + if isinstance(statement, sql.Insert): + primary_keys = table.engine.last_inserted_ids() + index = 0 + for col in table.primary_keys: + newid = primary_keys[index] + index += 1 + props[col].setattr(obj, newid) + self.put(obj) + # unset dirty flag + obj.dirty = False for prop in self.props.values(): - prop.save(obj, traverse, refetch) - self.transaction(f) + if not isinstance(prop, ColumnProperty): + prop.save(obj, traverse, refetch) + self.transaction(foo) else: for prop in self.props.values(): prop.save(obj, traverse, refetch) @@ -232,52 +264,6 @@ class Mapper(object): """removes the object. traverse indicates attached objects should be removed as well.""" pass - def insert(self, obj, table = None): - """inserts an object into one table, regardless of primary key being set. this is a - lower-level operation than save.""" - - if table is None: - table = self.table - - params = {} - for col in table.columns: - params[col.key] = getattr(obj, col.key, None) - ins = table.insert() - ins.echo = self.echo - ins.execute(**params) - - # unset dirty flag - obj.dirty = False - - # populate new primary keys - primary_keys = table.engine.last_inserted_ids() - index = 0 - for pk in table.primary_keys: - newid = primary_keys[index] - index += 1 - # TODO: do this via the ColumnProperty objects - setattr(obj, pk.key, newid) - - self.put(obj) - - def update(self, obj, table = None): - """updates an object in one table, regardless of primary key being set. this is a - lower-level operation than save.""" - - if table is None: - table = self.table - params = {} - clause = sql.and_() - for col in table.columns: - if col.primary_key: - clause.clauses.append(col == getattr(obj, col.key)) - else: - params[col.key] = getattr(obj, col.key) - upd = table.update(clause) - upd.echo = self.echo - upd.execute(**params) - # unset dirty flag - obj.dirty = False def delete(self, obj): """deletes the object's row from its table unconditionally. this is a lower-level @@ -392,11 +378,15 @@ class MapperProperty: class ColumnProperty(MapperProperty): """describes an object attribute that corresponds to a table column.""" - def __init__(self, column): - self.column = column + def __init__(self, *columns): + self.columns = list(columns) + def getattr(self, object): + return getattr(object, self.key, None) + def setattr(self, object, value): + setattr(object, self.key, value) def hash_key(self): - return "ColumnProperty(%s)" % hash_key(self.column) + return "ColumnProperty(%s)" % repr([hash_key(c) for c in self.columns]) def init(self, key, parent, root): self.key = key @@ -450,19 +440,20 @@ class PropertyLoader(MapperProperty): # if a mapping table exists, determine the two foreign key columns # in the mapping table, set the two values, and insert that row, for # each row in the list - for child in getattr(obj, self.key): - setter = ForeignKeySetter(obj, child) - self.primaryjoin.accept_visitor(setter) - self.mapper.save(child) + # for child in getattr(obj, self.key): + # setter = ForeignKeySetter(obj, child) + # self.primaryjoin.accept_visitor(setter) + # self.mapper.save(child) + pass def delete(self): self.mapper.delete() -class ForeignKeySetter(ClauseVisitor): - def visit_binary(self, binary): - if binary.operator == '==': - if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: - setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname)) +#class ForeignKeySetter(ClauseVisitor): + # def visit_binary(self, binary): + # if binary.operator == '==': + # if binary.left.table == self.primarytable and binary.right.table == self.secondarytable: + # setattr(self.child, binary.left.colname, getattr(obj, binary.right.colname)) class LazyLoader(PropertyLoader): diff --git a/test/mapper.py b/test/mapper.py index d1f29e4389..e8ed519f6f 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -232,14 +232,6 @@ class EagerTest(PersistTest): class SaveTest(PersistTest): - def testinsert(self): - u = User() - u.user_name = 'inserttester' - m = mapper(User, users, echo=True) - m.insert(u) -# nu = m.get(u.user_id) - nu = m.select(users.c.user_id == u.user_id)[0] - self.assert_(u is nu) def testsave(self): # save two users @@ -274,24 +266,24 @@ class SaveTest(PersistTest): def testsavemultitable(self): usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, table = users) + m = mapper(User, usersaddresses, table = users, echo = True, properties = dict(email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id))) u = User() u.user_name = 'multitester' - u.email_address = 'multi@test.org' + u.email = 'multi@test.org' m.save(u) - usertable = engine.ResultProxy(users.select().execute()).fetchall() - print repr(usertable) - addresstable = engine.ResultProxy(addresses.select().execute()).fetchall() - print repr(addresstable) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall() + self.assert_(usertable[0].row == (10, 'multitester')) + addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall() + self.assert_(addresstable[0].row == (4, 10, 'multi@test.org')) - u.email_address = 'lala@hey.com' + u.email = 'lala@hey.com' u.user_name = 'imnew' m.save(u) - usertable = engine.ResultProxy(users.select().execute()).fetchall() - print repr(usertable) - addresstable = engine.ResultProxy(addresses.select().execute()).fetchall() - print repr(addresstable) + usertable = engine.ResultProxy(users.select(users.c.user_id.in_(10)).execute()).fetchall() + self.assert_(usertable[0].row == (10, 'imnew')) + addresstable = engine.ResultProxy(addresses.select(addresses.c.address_id.in_(4)).execute()).fetchall() + self.assert_(addresstable[0].row == (4, 10, 'lala@hey.com')) if __name__ == "__main__": unittest.main() -- 2.47.2