From: Mike Bayer Date: Sun, 17 Feb 2008 18:13:14 +0000 (+0000) Subject: - modernized cascade.py tests X-Git-Tag: rel_0_4_4~67 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3e6e61dbe7b13be851e78db4280565f84c874c35;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - modernized cascade.py tests - your cries have been heard: removing a pending item from an attribute or collection with delete-orphan expunges the item from the session; no FlushError is raised. Note that if you session.save()'ed the pending item explicitly, the attribute/collection removal still knocks it out. --- diff --git a/CHANGES b/CHANGES index 56b3599fb2..cd8c1de4a7 100644 --- a/CHANGES +++ b/CHANGES @@ -8,6 +8,12 @@ CHANGES work properly with self-referential relations - the clause inside the EXISTS is aliased on the "remote" side to distinguish it from the parent table. + - your cries have been heard: removing a pending item + from an attribute or collection with delete-orphan + expunges the item from the session; no FlushError is raised. + Note that if you session.save()'ed the pending item + explicitly, the attribute/collection removal still knocks + it out. 0.4.4 ------ diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index d9131da416..0f87d73087 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -33,7 +33,7 @@ class UOWEventHandler(interfaces.AttributeExtension): session cascade operations. """ - def __init__(self, key, class_, cascade=None): + def __init__(self, key, class_, cascade): self.key = key self.class_ = class_ self.cascade = cascade @@ -41,27 +41,34 @@ class UOWEventHandler(interfaces.AttributeExtension): def append(self, obj, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = object_session(obj) - if sess is not None: - if self.cascade is not None and self.cascade.save_update and item not in sess: + if sess: + if self.cascade.save_update and item not in sess: mapper = object_mapper(obj) prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(item, entity_name=ename) def remove(self, obj, item, initiator): - # currently no cascade rules for removing an item from a list - # (i.e. it stays in the Session) - pass + sess = object_session(obj) + if sess: + # expunge pending orphans + if self.cascade.delete_orphan and item in sess.new: + sess.expunge(item) def set(self, obj, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance + if oldvalue is newvalue: + return sess = object_session(obj) - if sess is not None: - if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess: + if sess: + if newvalue is not None and self.cascade.save_update and newvalue not in sess: mapper = object_mapper(obj) prop = mapper.get_property(self.key) ename = prop.mapper.entity_name sess.save_or_update(newvalue, entity_name=ename) + if self.cascade.delete_orphan and oldvalue in sess.new: + sess.expunge(oldvalue) + def register_attribute(class_, key, *args, **kwargs): """overrides attributes.register_attribute() to add UOW event handlers diff --git a/test/orm/cascade.py b/test/orm/cascade.py index 7222655776..e2b2c61227 100644 --- a/test/orm/cascade.py +++ b/test/orm/cascade.py @@ -3,106 +3,44 @@ import testenv; testenv.configure_for_tests() from sqlalchemy import * from sqlalchemy import exceptions from sqlalchemy.orm import * -from sqlalchemy.ext.sessioncontext import SessionContext from testlib import * -import testlib.tables as tables +from testlib import fixtures -class O2MCascadeTest(TestBase, AssertsExecutionResults): - def tearDown(self): - tables.delete() +class O2MCascadeTest(fixtures.FixtureTest): + keep_mappers = True + keep_data = False + refresh_data = False - def tearDownAll(self): - clear_mappers() - tables.drop() - - def setUpAll(self): - global data - tables.create() - mapper(tables.User, tables.users, properties = dict( - address = relation(mapper(tables.Address, tables.addresses), lazy=True, uselist = False, cascade="all, delete-orphan"), + def setup_mappers(self): + global User, Address, Order, users, orders, addresses + from testlib.fixtures import User, Address, Order, users, orders, addresses + + mapper(Address, addresses) + mapper(User, users, properties = dict( + addresses = relation(Address, cascade="all, delete-orphan"), orders = relation( - mapper(tables.Order, tables.orders, properties = dict ( - items = relation(mapper(tables.Item, tables.orderitems), lazy=True, uselist =True, cascade="all, delete-orphan") - )), - lazy = True, uselist = True, cascade="all, delete-orphan") + mapper(Order, orders), cascade="all, delete-orphan") )) - - def setUp(self): - global data - data = [tables.User, - {'user_name' : 'ed', - 'address' : (tables.Address, {'email_address' : 'foo@bar.com'}), - 'orders' : (tables.Order, [ - {'description' : 'eds 1st order', 'items' : (tables.Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])}, - {'description' : 'eds 2nd order', 'items' : (tables.Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])} - ]) - }, - {'user_name' : 'jack', - 'address' : (tables.Address, {'email_address' : 'jack@jack.com'}), - 'orders' : (tables.Order, [ - {'description' : 'jacks 1st order', 'items' : (tables.Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])} - ]) - }, - {'user_name' : 'foo', - 'address' : (tables.Address, {'email_address': 'hi@lala.com'}), - 'orders' : (tables.Order, [ - {'description' : 'foo order', 'items' : (tables.Item, [])}, - {'description' : 'foo order 2', 'items' : (tables.Item, [{'item_name' : 'hi'}])}, - {'description' : 'foo order three', 'items' : (tables.Item, [{'item_name' : 'there'}])} - ]) - } - ] - + + def test_list_assignment(self): sess = create_session() - for elem in data[1:]: - u = tables.User() - sess.save(u) - u.user_name = elem['user_name'] - u.address = tables.Address() - u.address.email_address = elem['address'][1]['email_address'] - u.orders = [] - for order in elem['orders'][1]: - o = tables.Order() - o.isopen = None - o.description = order['description'] - u.orders.append(o) - o.items = [] - for item in order['items'][1]: - i = tables.Item() - i.item_name = item['item_name'] - o.items.append(i) - - sess.flush() - sess.clear() - - def testassignlist(self): - sess = create_session() - u = tables.User() - u.user_name = 'jack' - o1 = tables.Order() - o1.description ='someorder' - o2 = tables.Order() - o2.description = 'someotherorder' - l = [o1, o2] + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) sess.save(u) - u.orders = l - assert o1 in sess - assert o2 in sess sess.flush() sess.clear() - - u = sess.query(tables.User).get(u.user_id) - o3 = tables.Order() - o3.description='order3' - o4 = tables.Order() - o4.description = 'order4' - u.orders = [o3, o4] - assert o3 in sess - assert o4 in sess + + u = sess.query(User).get(u.id) + self.assertEquals(u, User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])) + + u.orders=[Order(description="order 3"), Order(description="order 4")] sess.flush() + sess.clear() + + u = sess.query(User).get(u.id) + self.assertEquals(u, User(name='jack', orders=[Order(description="order 3"), Order(description="order 4")])) - o5 = tables.Order() - o5.description='order5' + self.assertEquals(sess.query(Order).all(), [Order(description="order 3"), Order(description="order 4")]) + o5 = Order(description="order 5") sess.save(o5) try: sess.flush() @@ -110,108 +48,106 @@ class O2MCascadeTest(TestBase, AssertsExecutionResults): except exceptions.FlushError, e: assert "is an orphan" in str(e) - - def testdelete(self): + def test_delete(self): sess = create_session() - l = sess.query(tables.User).all() - for u in l: - print repr(u.orders) - self.assert_result(l, data[0], *data[1:]) - - ids = (l[0].user_id, l[2].user_id) - sess.delete(l[0]) - sess.delete(l[2]) + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() + sess.delete(u) sess.flush() - assert tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 0 - assert tables.orderitems.count(tables.orders.c.user_id.in_(ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0 - assert tables.addresses.count(tables.addresses.c.user_id.in_(ids)).scalar() == 0 - assert tables.users.count(tables.users.c.user_id.in_(ids)).scalar() == 0 + assert users.count().scalar() == 0 + assert orders.count().scalar() == 0 - def testdelete2(self): + def test_delete_unloaded_collections(self): """test that unloaded collections are still included in a delete-cascade by default.""" sess = create_session() - u = sess.query(tables.User).filter_by(user_name='ed').one() - # assert 'addresses' collection not loaded + u = User(name='jack', addresses=[Address(email_address="address1"), Address(email_address="address2")]) + sess.save(u) + sess.flush() + sess.clear() + assert addresses.count().scalar() == 2 + assert users.count().scalar() == 1 + + u = sess.query(User).get(u.id) + assert 'addresses' not in u.__dict__ sess.delete(u) sess.flush() - assert tables.addresses.count(tables.addresses.c.email_address=='foo@bar.com').scalar() == 0 - assert tables.orderitems.count(tables.orderitems.c.item_name.like('eds%')).scalar() == 0 + assert addresses.count().scalar() == 0 + assert users.count().scalar() == 0 - def testcascadecollection(self): + def test_cascades_onlycollection(self): """test that cascade only reaches instances that are still part of the collection, not those that have been removed""" - sess = create_session() - u = tables.User() - u.user_name = 'newuser' - o = tables.Order() - o.description = "some description" - u.orders.append(o) + sess = create_session() + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) sess.save(u) sess.flush() - - u.orders.remove(o) + + o = u.orders[0] + del u.orders[0] sess.delete(u) assert u in sess.deleted assert o not in sess.deleted + assert o in sess - - def testorphan(self): + u2 = User(name='newuser', orders=[o]) + sess.save(u2) + sess.flush() + sess.clear() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 1 + self.assertEquals(sess.query(User).all(), [User(name='newuser', orders=[Order(description='someorder')])]) + + def test_collection_orphans(self): sess = create_session() - l = sess.query(tables.User).all() - jack = l[1] - jack.orders[:] = [] + u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]) + sess.save(u) + sess.flush() - ids = [jack.user_id] - self.assert_(tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 1) - self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 2) + assert users.count().scalar() == 1 + assert orders.count().scalar() == 2 - sess.flush() + u.orders[:] = [] - self.assert_(tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 0) - self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(ids) &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0) + sess.flush() + assert users.count().scalar() == 1 + assert orders.count().scalar() == 0 -class M2OCascadeTest(TestBase, AssertsExecutionResults): - def tearDown(self): - ctx.current.clear() - for t in metadata.table_iterator(reverse=True): - t.delete().execute() - def tearDownAll(self): - clear_mappers() - metadata.drop_all() +class M2OCascadeTest(ORMTest): + keep_mappers = True + + def define_tables(self, metadata): + global extra, prefs, users - @testing.uses_deprecated('SessionContext') - def setUpAll(self): - global ctx, data, metadata, User, Pref, Extra - ctx = SessionContext(create_session) - metadata = MetaData(testing.db) extra = Table("extra", metadata, - Column("extra_id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), - Column("prefs_id", Integer, ForeignKey("prefs.prefs_id")) + Column("id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True), + Column("prefs_id", Integer, ForeignKey("prefs.id")) ) prefs = Table('prefs', metadata, - Column('prefs_id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True), - Column('prefs_data', String(40))) + Column('id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True), + Column('data', String(40))) users = Table('users', metadata, - Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), - Column('user_name', String(40)), - Column('pref_id', Integer, ForeignKey('prefs.prefs_id')) + Column('id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), + Column('name', String(40)), + Column('pref_id', Integer, ForeignKey('prefs.id')) ) - class User(object): - def __init__(self, name): - self.user_name = name - class Pref(object): - def __init__(self, data): - self.prefs_data = data - class Extra(object): + + def setup_mappers(self): + global User, Pref, Extra + class User(fixtures.Base): pass - metadata.create_all() + class Pref(fixtures.Base): + pass + class Extra(fixtures.Base): + pass + mapper(Extra, extra) mapper(Pref, prefs, properties=dict( extra = relation(Extra, cascade="all, delete") @@ -221,64 +157,73 @@ class M2OCascadeTest(TestBase, AssertsExecutionResults): )) def setUp(self): - u1 = User("ed") - u1.pref = Pref("pref 1") - u2 = User("jack") - u2.pref = Pref("pref 2") - u3 = User("foo") - u3.pref = Pref("pref 3") - u1.pref.extra.append(Extra()) - u2.pref.extra.append(Extra()) - u2.pref.extra.append(Extra()) - - ctx.current.save(u1) - ctx.current.save(u2) - ctx.current.save(u3) - ctx.current.flush() - ctx.current.clear() + u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()])) + u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()])) + u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()])) + sess = create_session() + sess.save(u1) + sess.save(u2) + sess.save(u3) + sess.flush() + sess.close() @testing.fails_on('maxdb') - def testorphan(self): - jack = ctx.current.query(User).filter_by(user_name='jack').one() - p = jack.pref - e = jack.pref.extra[0] + def test_orphan(self): + sess = create_session() + assert prefs.count().scalar() == 3 + assert extra.count().scalar() == 3 + jack = sess.query(User).filter_by(name="jack").one() jack.pref = None - ctx.current.flush() - assert p not in ctx.current - assert e not in ctx.current + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 @testing.fails_on('maxdb') - def testorphan2(self): - jack = ctx.current.query(User).filter_by(user_name='jack').one() + def test_orphan_on_update(self): + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() p = jack.pref e = jack.pref.extra[0] - ctx.current.clear() + sess.clear() jack.pref = None - ctx.current.update(jack) - ctx.current.update(p) - ctx.current.update(e) - assert p in ctx.current - assert e in ctx.current - ctx.current.flush() - assert p not in ctx.current - assert e not in ctx.current - - def testorphan3(self): + sess.update(jack) + sess.update(p) + sess.update(e) + assert p in sess + assert e in sess + sess.flush() + assert prefs.count().scalar() == 2 + assert extra.count().scalar() == 2 + + def test_pending_expunge(self): + sess = create_session() + someuser = User(name='someuser') + sess.save(someuser) + sess.flush() + someuser.pref = p1 = Pref(data='somepref') + assert p1 in sess + someuser.pref = Pref(data='someotherpref') + assert p1 not in sess + sess.flush() + self.assertEquals(sess.query(Pref).with_parent(someuser).all(), [Pref(data="someotherpref")]) + + + def test_double_assignment(self): """test that double assignment doesn't accidentally reset the 'parent' flag.""" - jack = ctx.current.query(User).filter_by(user_name='jack').one() - newpref = Pref("newpref") + sess = create_session() + jack = sess.query(User).filter_by(name="jack").one() + + newpref = Pref(data="newpref") jack.pref = newpref jack.pref = newpref - ctx.current.flush() - - + sess.flush() + self.assertEquals(sess.query(Pref).all(), [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")]) -class M2MCascadeTest(TestBase, AssertsExecutionResults): - def setUpAll(self): - global metadata, a, b, atob - metadata = MetaData(testing.db) +class M2MCascadeTest(ORMTest): + def define_tables(self, metadata): + global a, b, atob a = Table('a', metadata, Column('id', Integer, primary_key=True), Column('data', String(30)) @@ -292,18 +237,12 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults): Column('bid', Integer, ForeignKey('b.id')) ) - metadata.create_all() - - def tearDownAll(self): - metadata.drop_all() - def testdeleteorphan(self): - class A(object): - def __init__(self, data): - self.data = data - class B(object): - def __init__(self, data): - self.data = data + def test_delete_orphan(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass mapper(A, a, properties={ # if no backref here, delete-orphan failed until [ticket:427] was fixed @@ -312,9 +251,8 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults): mapper(B, b) sess = create_session() - a1 = A('a1') - b1 = B('b1') - a1.bs.append(b1) + b1 = B(data='b1') + a1 = A(data='a1', bs=[b1]) sess.save(a1) sess.flush() @@ -324,13 +262,11 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults): assert b.count().scalar() == 0 assert a.count().scalar() == 1 - def testcascadedelete(self): - class A(object): - def __init__(self, data): - self.data = data - class B(object): - def __init__(self, data): - self.data = data + def test_cascade_delete(self): + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass mapper(A, a, properties={ 'bs':relation(B, secondary=atob, cascade="all, delete-orphan") @@ -338,9 +274,7 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults): mapper(B, b) sess = create_session() - a1 = A('a1') - b1 = B('b1') - a1.bs.append(b1) + a1 = A(data='a1', bs=[B(data='b1')]) sess.save(a1) sess.flush() @@ -357,7 +291,7 @@ class UnsavedOrphansTest(ORMTest): global users, addresses, User, Address users = Table('users', metadata, Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True), - Column('user_name', String(40)), + Column('name', String(40)), ) addresses = Table('email_addresses', metadata, @@ -365,10 +299,10 @@ class UnsavedOrphansTest(ORMTest): Column('user_id', Integer, ForeignKey(users.c.user_id)), Column('email_address', String(40)), ) - class User(object):pass - class Address(object):pass + class User(fixtures.Base):pass + class Address(fixtures.Base):pass - def test_pending_orphan(self): + def test_pending_standalone_orphan(self): """test that an entity that never had a parent on a delete-orphan cascade cant be saved.""" mapper(Address, addresses) @@ -384,9 +318,8 @@ class UnsavedOrphansTest(ORMTest): pass assert a.address_id is None, "Error: address should not be persistent" - def test_delete_new_object(self): - """test that an entity which is attached then detached from its - parent with a delete-orphan cascade gets counted as an orphan""" + def test_pending_collection_expunge(self): + """test that removing a pending item from a collection expunges it from the session.""" mapper(Address, addresses) mapper(User, users, properties=dict( @@ -398,18 +331,33 @@ class UnsavedOrphansTest(ORMTest): s.save(u) s.flush() a = Address() - assert a not in s.new + u.addresses.append(a) + assert a in s + u.addresses.remove(a) - s.delete(u) - try: - s.flush() # (erroneously) causes "a" to be persisted - assert False - except exceptions.FlushError: - assert True - assert a.address_id is None, "Error: address should not be persistent" + assert a not in s + s.delete(u) + s.flush() + assert a.address_id is None, "Error: address should not be persistent" + + def test_nonorphans_ok(self): + mapper(Address, addresses) + mapper(User, users, properties=dict( + addresses=relation(Address, cascade="all,delete", backref="user") + )) + s = create_session() + u = User(name='u1', addresses=[Address(email_address='ad1')]) + s.save(u) + a1 = u.addresses[0] + u.addresses.remove(a1) + assert a1 in s + s.flush() + s.clear() + self.assertEquals(s.query(Address).all(), [Address(email_address='ad1')]) + class UnsavedOrphansTest2(ORMTest): """same test as UnsavedOrphans only three levels deep""" @@ -433,56 +381,52 @@ class UnsavedOrphansTest2(ORMTest): ) - def testdeletechildwithchild(self): - """test that an entity which is attached then detached from its - parent with a delete-orphan cascade gets counted as an orphan, as well - as its own child instances""" - - class Order(object): pass - class Item(object): pass - class Attribute(object): pass + def test_pending_expunge(self): + class Order(fixtures.Base): + pass + class Item(fixtures.Base): + pass + class Attribute(fixtures.Base): + pass - attrMapper = mapper(Attribute, attributes) - itemMapper = mapper(Item, items, properties=dict( - attributes=relation(attrMapper, cascade="all,delete-orphan", backref="item") + mapper(Attribute, attributes) + mapper(Item, items, properties=dict( + attributes=relation(Attribute, cascade="all,delete-orphan", backref="item") )) - orderMapper = mapper(Order, orders, properties=dict( - items=relation(itemMapper, cascade="all,delete-orphan", backref="order") + mapper(Order, orders, properties=dict( + items=relation(Item, cascade="all,delete-orphan", backref="order") )) - s = create_session( ) - order = Order() + s = create_session() + order = Order(name="order1") s.save(order) - item = Item() - attr = Attribute() - item.attributes.append(attr) + attr = Attribute(name="attr1") + item = Item(name="item1", attributes=[attr]) order.items.append(item) - order.items.remove(item) # item is an orphan, but attr is not so flush() tries to save attr - try: - s.flush() - assert False - except exceptions.FlushError, e: - print e - assert True - - assert item.id is None - assert attr.id is None + order.items.remove(item) + + assert item not in s + assert attr not in s + + s.flush() + assert orders.count().scalar() == 1 + assert items.count().scalar() == 0 + assert attributes.count().scalar() == 0 -class DoubleParentOrphanTest(TestBase, AssertsExecutionResults): +class DoubleParentOrphanTest(ORMTest): """test orphan detection for an entity with two parent relations""" - def setUpAll(self): - global metadata, address_table, businesses, homes - metadata = MetaData(testing.db) + def define_tables(self, metadata): + global address_table, businesses, homes address_table = Table('addresses', metadata, Column('address_id', Integer, primary_key=True), Column('street', String(30)), ) homes = Table('homes', metadata, - Column('home_id', Integer, primary_key=True), + Column('home_id', Integer, primary_key=True, key="id"), Column('description', String(30)), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) @@ -492,38 +436,42 @@ class DoubleParentOrphanTest(TestBase, AssertsExecutionResults): Column('description', String(30), key="description"), Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False), ) - metadata.create_all() - def tearDown(self): - clear_mappers() - def tearDownAll(self): - metadata.drop_all() + def test_non_orphan(self): """test that an entity can have two parent delete-orphan cascades, and persists normally.""" - class Address(object):pass - class Home(object):pass - class Business(object):pass + class Address(fixtures.Base): + pass + class Home(fixtures.Base): + pass + class Business(fixtures.Base): + pass + mapper(Address, address_table) mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")}) mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")}) session = create_session() - a1 = Address() - a2 = Address() - h1 = Home() - b1 = Business() - h1.address = a1 - b1.address = a2 + h1 = Home(description='home1', address=Address(street='address1')) + b1 = Business(description='business1', address=Address(street='address2')) [session.save(x) for x in [h1,b1]] session.flush() + session.clear() + + self.assertEquals(session.query(Home).get(h1.id), Home(description='home1', address=Address(street='address1'))) + self.assertEquals(session.query(Business).get(b1.id), Business(description='business1', address=Address(street='address2'))) def test_orphan(self): """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan when saved without a parent.""" - class Address(object):pass - class Home(object):pass - class Business(object):pass + class Address(fixtures.Base): + pass + class Home(fixtures.Base): + pass + class Business(fixtures.Base): + pass + mapper(Address, address_table) mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")}) mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")}) @@ -537,60 +485,47 @@ class DoubleParentOrphanTest(TestBase, AssertsExecutionResults): except exceptions.FlushError, e: assert True -class CollectionAssignmentOrphanTest(TestBase, AssertsExecutionResults): - def setUpAll(self): - global metadata, table_a, table_b +class CollectionAssignmentOrphanTest(ORMTest): + def define_tables(self, metadata): + global table_a, table_b - metadata = MetaData(testing.db) table_a = Table('a', metadata, Column('id', Integer, primary_key=True), - Column('foo', String(30))) + Column('name', String(30))) table_b = Table('b', metadata, Column('id', Integer, primary_key=True), - Column('foo', String(30)), + Column('name', String(30)), Column('a_id', Integer, ForeignKey('a.id'))) - metadata.create_all() - - def tearDown(self): - clear_mappers() - def tearDownAll(self): - metadata.drop_all() def test_basic(self): - class A(object): - def __init__(self, foo): - self.foo = foo - class B(object): - def __init__(self, foo): - self.foo = foo + class A(fixtures.Base): + pass + class B(fixtures.Base): + pass mapper(A, table_a, properties={ 'bs':relation(B, cascade="all, delete-orphan") }) mapper(B, table_b) - a1 = A('a1') - a1.bs.append(B('b1')) - a1.bs.append(B('b2')) - a1.bs.append(B('b3')) + a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]) sess = create_session() sess.save(a1) sess.flush() - assert table_b.count(table_b.c.a_id == None).scalar() == 0 - - assert table_b.count().scalar() == 3 + sess.clear() + + self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) a1 = sess.query(A).get(a1.id) - assert len(a1.bs) == 3 - a1.bs = list(a1.bs) assert not class_mapper(B)._is_orphan(a1.bs[0]) a1.bs[0].foo='b2modified' a1.bs[1].foo='b3modified' sess.flush() - assert table_b.count().scalar() == 3 + sess.clear() + self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])) if __name__ == "__main__": testenv.main() diff --git a/test/orm/query.py b/test/orm/query.py index a2ac6cf4e0..d48abd5d2e 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -13,10 +13,6 @@ class QueryTest(FixtureTest): keep_mappers = True keep_data = True - def setUpAll(self): - super(QueryTest, self).setUpAll() - self.setup_mappers() - def setup_mappers(self): mapper(User, users, properties={ 'addresses':relation(Address, backref='user'), diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 8a30287996..1d7062c882 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -648,11 +648,15 @@ class ORMTest(TestBase, AssertsExecutionResults): _otest_metadata.bind = config.db self.define_tables(_otest_metadata) _otest_metadata.create_all() + self.setup_mappers() self.insert_data() def define_tables(self, _otest_metadata): raise NotImplementedError() - + + def setup_mappers(self): + pass + def insert_data(self): pass