From ea190e295e29b27789ccca3b59dd002658748e5f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 4 Sep 2005 18:48:26 +0000 Subject: [PATCH] --- lib/sqlalchemy/mapper.py | 57 ++++++++++++++++++++-------------------- test/mapper.py | 57 ++++++++++++++++++++++++---------------- test/tables.py | 4 +-- 3 files changed, 66 insertions(+), 52 deletions(-) diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 46dfde3e84..2b981e647e 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -216,7 +216,7 @@ class Mapper(object): def _setattrbycolumn(self, obj, column, value): self.columntoproperty[column][0].setattr(obj, value) - def save(self, obj, traverse = True, refetch = False): + def save(self, obj, traverse = True): """saves the object across all its primary tables. based on the existence of the primary key for each table, either inserts or updates. primary key is determined by the underlying database engine's sequence methodology. @@ -229,6 +229,8 @@ class Mapper(object): if getattr(obj, 'dirty', True): def foo(): + insert_statement = None + update_statement = None for table in self.tables: params = {} # TODO: prepare the insert() and update() - (1) within the code or @@ -264,11 +266,11 @@ class Mapper(object): obj.dirty = False for prop in self.props.values(): if not isinstance(prop, ColumnProperty): - prop.save(obj, traverse, refetch) + prop.save(obj, traverse) self.transaction(foo) else: for prop in self.props.values(): - prop.save(obj, traverse, refetch) + prop.save(obj, traverse) def transaction(self, f): return self.table.engine.multi_transaction(self.tables, f) @@ -277,15 +279,6 @@ class Mapper(object): """removes the object. traverse indicates attached objects should be removed as well.""" pass - def delete(self, obj): - """deletes the object's row from its table unconditionally. this is a lower-level - operation than remove.""" - # delete dependencies ? - # delete row - # remove primary keys - # unset dirty flag - pass - def _compile(self, whereclause = None, **options): statement = sql.select([self.selectable], whereclause) for key, value in self.props.iteritems(): @@ -365,7 +358,7 @@ class MapperProperty: """called when the MapperProperty is first attached to a new parent Mapper.""" pass - def save(self, object, traverse, refetch): + def save(self, object, traverse): """called when the instance is being saved""" pass @@ -432,48 +425,56 @@ class PropertyLoader(MapperProperty): if self.primaryjoin is None: self.primaryjoin = match_primaries(parent.selectable, self.target) - def save(self, obj, traverse, refetch): + def save(self, obj, traverse): # saves child objects - # TODO: put association table inserts/deletes into one batch - #if self.secondary is not None: - # secondary_delete = self.secondary.delete(sql.and_([c == bindparam(c.key) for c in setter.secondary.c])) - + if self.secondary is not None: + secondary_delete = [] + secondary_insert = [] + setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj) childlist = getattr(obj, self.key) if not isinstance(childlist, util.HistoryArraySet): childlist = util.HistoryArraySet(childlist) clean_setattr(obj, self.key, childlist) + for child in childlist.deleted_items(): setter.child = child + setter.associationrow = {} setter.clearkeys = True self.primaryjoin.accept_visitor(setter) child.dirty = True - self.mapper.save(child) + self.mapper.save(child, traverse) if self.secondary is not None: self.secondaryjoin.accept_visitor(setter) - # TODO: prepare this above - statement = self.secondary.delete(sql.and_(*[c == setter.associationrow[c.key] for c in self.secondary.c])) - statement.echo = self.mapper.echo - statement.execute() + secondary_delete.append(setter.associationrow) + for child in childlist.added_items(): setter.child = child + setter.associationrow = {} self.primaryjoin.accept_visitor(setter) child.dirty = True - self.mapper.save(child) + self.mapper.save(child, traverse) if self.secondary is not None: self.secondaryjoin.accept_visitor(setter) - # TODO: prepare this above + secondary_insert.append(setter.associationrow) + + if self.secondary is not None: + if len(secondary_delete): + statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key) for c in self.secondary.c])) + statement.echo = self.mapper.echo + statement.execute(*secondary_delete) + if len(secondary_insert): statement = self.secondary.insert() statement.echo = self.mapper.echo - statement.execute(**setter.associationrow) + statement.execute(*secondary_insert) + for child in childlist.unchanged_items(): - self.mapper.save(child) + self.mapper.save(child, traverse) # TODO: if transaction fails state is invalid # use unit of work ? childlist.clear_history() - def delete(self): self.mapper.delete() diff --git a/test/mapper.py b/test/mapper.py index 7f129cf945..ba37335ace 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -2,6 +2,7 @@ from testbase import PersistTest import unittest, sys, os from sqlalchemy.mapper import * +ECHO = False execfile("test/tables.py") class User(object): @@ -58,7 +59,7 @@ class MapperTest(AssertMixin): #globalidentity().clear() def testget(self): - m = mapper(User, users, scope = "thread", echo = True) + m = mapper(User, users, scope = "thread") self.assert_(m.get(19) is None) u = m.get(7) u2 = m.get(7) @@ -85,7 +86,7 @@ class MapperTest(AssertMixin): """tests that a lazy relation can be upgraded to an eager relation via the options method""" m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = True) - ), echo = True) + )) l = m.options(eagerload('addresses')).select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, @@ -97,7 +98,7 @@ class MapperTest(AssertMixin): """tests that an eager relation can be upgraded to a lazy relation via the options method""" m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = False) - ), echo = True) + )) l = m.options(lazyload('addresses')).select() self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, @@ -114,7 +115,7 @@ class LazyTest(AssertMixin): """tests a basic one-to-many lazy load""" m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = True) - ), echo = True) + )) l = m.select(users.c.user_id == 7) self.assert_result(l, User, {'user_id' : 7, 'addresses' : (Address, [{'address_id' : 1}])}, @@ -126,7 +127,7 @@ class LazyTest(AssertMixin): m = mapper(Item, items, properties = dict( keywords = relation(Keyword, keywords, itemkeywords, lazy = True), - ), echo = True) + )) l = m.select() self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, @@ -156,7 +157,7 @@ class EagerTest(PersistTest): m = mapper(User, users, properties = dict( #addresses = relation(Address, addresses, lazy = False), addresses = relation(m, lazy = False), - ), echo = True) + )) l = m.select() print repr(l) @@ -166,7 +167,7 @@ class EagerTest(PersistTest): criterion doesnt interfere with the eager load criterion.""" m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, primaryjoin = users.c.user_id==addresses.c.user_id, lazy = False) - ), echo = True) + )) l = m.select(and_(addresses.c.email_address == 'ed@lala.com', addresses.c.user_id==users.c.user_id)) print repr(l) @@ -199,7 +200,7 @@ class EagerTest(PersistTest): m = mapper(User, users, properties = dict( orders_open = relation(Order, openorders, primaryjoin = and_(openorders.c.isopen == 1, users.c.user_id==openorders.c.user_id), lazy = False), orders_closed = relation(Order, closedorders, primaryjoin = and_(closedorders.c.isopen == 0, users.c.user_id==closedorders.c.user_id), lazy = False) - ), echo = True) + )) l = m.select() print repr(l) @@ -213,7 +214,7 @@ class EagerTest(PersistTest): m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = False), orders = relation(ordermapper, primaryjoin = users.c.user_id==orders.c.user_id, lazy = False), - ), echo = True) + )) l = m.select() print repr(l) @@ -222,7 +223,7 @@ class EagerTest(PersistTest): m = mapper(Item, items, properties = dict( keywords = relation(Keyword, keywords, itemkeywords, lazy = False), - ), echo = True) + )) l = m.select() print repr(l) @@ -235,17 +236,16 @@ class EagerTest(PersistTest): m = mapper(Item, items, properties = dict( keywords = relation(Keyword, keywords, itemkeywords, lazy = False), - ), - echo = True) + )) m = mapper(Order, orders, properties = dict( items = relation(m, lazy = False) - ), echo = True) + )) l = m.select("orders.order_id in (1,2,3)") #l = m.select() print repr(l) -class SaveTest(PersistTest): +class SaveTest(AssertMixin): def testbasic(self): # save two users @@ -253,7 +253,7 @@ class SaveTest(PersistTest): u.user_name = 'savetester' u2 = User() u2.user_name = 'savetester2' - m = mapper(User, users, echo=True) + m = mapper(User, users) m.save(u) m.save(u2) @@ -282,7 +282,7 @@ class SaveTest(PersistTest): """tests a save of an object where each instance spans two tables. also tests redefinition of the keynames for the column properties.""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) - m = mapper(User, usersaddresses, table = users, echo = True, + m = mapper(User, usersaddresses, table = users, properties = dict( email = ColumnProperty(addresses.c.email_address), foo_id = ColumnProperty(users.c.user_id, addresses.c.user_id) @@ -314,7 +314,7 @@ class SaveTest(PersistTest): """test basic save of one to many.""" m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = True) - ), echo = True) + )) u = User() u.user_name = 'one2manytester' u.addresses = [] @@ -344,7 +344,7 @@ class SaveTest(PersistTest): """tests that an alias of a table can be used in a mapper. the mapper has to locate the original table and columns to keep it all straight.""" ualias = Alias(users, 'ualias') - m = mapper(User, ualias, echo = True) + m = mapper(User, ualias) u = User() u.user_name = 'testalias' m.save(u) @@ -355,7 +355,7 @@ class SaveTest(PersistTest): def testremove(self): m = mapper(User, users, properties = dict( addresses = relation(Address, addresses, lazy = True) - ), echo = True) + )) u = User() u.user_name = 'one2manytester' u.addresses = [] @@ -381,7 +381,7 @@ class SaveTest(PersistTest): m = mapper(Item, items, properties = dict( keywords = relation(Keyword, keywords, itemkeywords, lazy = False), - ), echo = True) + )) keywordmapper = mapper(Keyword, keywords) @@ -395,12 +395,25 @@ class SaveTest(PersistTest): for k in klist: item.keywords.append(k) m.save(item) - print repr(m.select(items.c.item_id == item.item_id)) + l = m.select(items.c.item_id == item.item_id) + + self.assert_result(l, Item, + {'item_id' : item.item_id, 'keywords' : (Keyword, [ + {'name' : 'purple'}, + {'name' : 'blue'}, + {'name' : 'big'}, + {'name' : 'round'} + ])}) del item.keywords[2] del item.keywords[2] m.save(item) - print repr(m.select(items.c.item_id == item.item_id)) + l = m.select(items.c.item_id == item.item_id) + self.assert_result(l, Item, + {'item_id' : item.item_id, 'keywords' : (Keyword, [ + {'name' : 'purple'}, + {'name' : 'blue'}, + ])}) if __name__ == "__main__": unittest.main() diff --git a/test/tables.py b/test/tables.py index bddb7b499d..3822ab9280 100644 --- a/test/tables.py +++ b/test/tables.py @@ -8,12 +8,12 @@ DBTYPE = 'sqlite_memory' if DBTYPE == 'sqlite_memory': import sqlalchemy.databases.sqlite as sqllite - db = sqllite.engine(':memory:', {}, echo = False) + db = sqllite.engine(':memory:', {}, echo = ECHO) elif DBTYPE == 'sqlite_file': import sqlalchemy.databases.sqlite as sqllite if os.access('querytest.db', os.F_OK): os.remove('querytest.db') - db = sqllite.engine('querytest.db', opts = {}, echo = True) + db = sqllite.engine('querytest.db', opts = {}, echo = ECHO) elif DBTYPE == 'postgres': pass -- 2.47.2