]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- modernized cascade.py tests
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Feb 2008 18:13:14 +0000 (18:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Feb 2008 18:13:14 +0000 (18:13 +0000)
- your cries have been heard:  removing a pending item
from an attribute or collection with delete-orphan
expunges the item from the session; no FlushError is raised.
Note that if you session.save()'ed the pending item
explicitly, the attribute/collection removal still knocks
it out.

CHANGES
lib/sqlalchemy/orm/unitofwork.py
test/orm/cascade.py
test/orm/query.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index 56b3599fb2af9a0eeaac477446fffc5bc6c7ba82..cd8c1de4a7ba812492f7d91fa77af19794c71441 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,12 @@ CHANGES
       work properly with self-referential relations - the clause
       inside the EXISTS is aliased on the "remote" side to
       distinguish it from the parent table.
+    - your cries have been heard:  removing a pending item
+      from an attribute or collection with delete-orphan 
+      expunges the item from the session; no FlushError is raised.  
+      Note that if you session.save()'ed the pending item 
+      explicitly, the attribute/collection removal still knocks 
+      it out.
       
 0.4.4
 ------
index d9131da41671d0d139232f09f5c8e84adfb67f43..0f87d730878c4027a0a7280dd81385a88e1c3bb3 100644 (file)
@@ -33,7 +33,7 @@ class UOWEventHandler(interfaces.AttributeExtension):
     session cascade operations.
     """
 
-    def __init__(self, key, class_, cascade=None):
+    def __init__(self, key, class_, cascade):
         self.key = key
         self.class_ = class_
         self.cascade = cascade
@@ -41,27 +41,34 @@ class UOWEventHandler(interfaces.AttributeExtension):
     def append(self, obj, item, initiator):
         # process "save_update" cascade rules for when an instance is appended to the list of another instance
         sess = object_session(obj)
-        if sess is not None:
-            if self.cascade is not None and self.cascade.save_update and item not in sess:
+        if sess:
+            if self.cascade.save_update and item not in sess:
                 mapper = object_mapper(obj)
                 prop = mapper.get_property(self.key)
                 ename = prop.mapper.entity_name
                 sess.save_or_update(item, entity_name=ename)
 
     def remove(self, obj, item, initiator):
-        # currently no cascade rules for removing an item from a list
-        # (i.e. it stays in the Session)
-        pass
+        sess = object_session(obj)
+        if sess:
+            # expunge pending orphans
+            if self.cascade.delete_orphan and item in sess.new:
+                sess.expunge(item)
 
     def set(self, obj, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance is attached to another instance
+        if oldvalue is newvalue:
+            return
         sess = object_session(obj)
-        if sess is not None:
-            if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess:
+        if sess:
+            if newvalue is not None and self.cascade.save_update and newvalue not in sess:
                 mapper = object_mapper(obj)
                 prop = mapper.get_property(self.key)
                 ename = prop.mapper.entity_name
                 sess.save_or_update(newvalue, entity_name=ename)
+            if self.cascade.delete_orphan and oldvalue in sess.new:
+                sess.expunge(oldvalue)
+
 
 def register_attribute(class_, key, *args, **kwargs):
     """overrides attributes.register_attribute() to add UOW event handlers
index 7222655776fb98b28c47e7c3c45ce0074914a7dc..e2b2c61227695789865ab0f106cfdbda956257c0 100644 (file)
@@ -3,106 +3,44 @@ import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy import exceptions
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
 from testlib import *
-import testlib.tables as tables
+from testlib import fixtures
 
-class O2MCascadeTest(TestBase, AssertsExecutionResults):
-    def tearDown(self):
-        tables.delete()
+class O2MCascadeTest(fixtures.FixtureTest):
+    keep_mappers = True
+    keep_data = False
+    refresh_data = False
 
-    def tearDownAll(self):
-        clear_mappers()
-        tables.drop()
-
-    def setUpAll(self):
-        global data
-        tables.create()
-        mapper(tables.User, tables.users, properties = dict(
-            address = relation(mapper(tables.Address, tables.addresses), lazy=True, uselist = False, cascade="all, delete-orphan"),
+    def setup_mappers(self):
+        global User, Address, Order, users, orders, addresses
+        from testlib.fixtures import User, Address, Order, users, orders, addresses
+                
+        mapper(Address, addresses)
+        mapper(User, users, properties = dict(
+            addresses = relation(Address, cascade="all, delete-orphan"),
             orders = relation(
-                mapper(tables.Order, tables.orders, properties = dict (
-                    items = relation(mapper(tables.Item, tables.orderitems), lazy=True, uselist =True, cascade="all, delete-orphan")
-                )),
-                lazy = True, uselist = True, cascade="all, delete-orphan")
+                mapper(Order, orders), cascade="all, delete-orphan")
         ))
-
-    def setUp(self):
-        global data
-        data = [tables.User,
-            {'user_name' : 'ed',
-                'address' : (tables.Address, {'email_address' : 'foo@bar.com'}),
-                'orders' : (tables.Order, [
-                    {'description' : 'eds 1st order', 'items' : (tables.Item, [{'item_name' : 'eds o1 item'}, {'item_name' : 'eds other o1 item'}])},
-                    {'description' : 'eds 2nd order', 'items' : (tables.Item, [{'item_name' : 'eds o2 item'}, {'item_name' : 'eds other o2 item'}])}
-                 ])
-            },
-            {'user_name' : 'jack',
-                'address' : (tables.Address, {'email_address' : 'jack@jack.com'}),
-                'orders' : (tables.Order, [
-                    {'description' : 'jacks 1st order', 'items' : (tables.Item, [{'item_name' : 'im a lumberjack'}, {'item_name' : 'and im ok'}])}
-                 ])
-            },
-            {'user_name' : 'foo',
-                'address' : (tables.Address, {'email_address': 'hi@lala.com'}),
-                'orders' : (tables.Order, [
-                    {'description' : 'foo order', 'items' : (tables.Item, [])},
-                    {'description' : 'foo order 2', 'items' : (tables.Item, [{'item_name' : 'hi'}])},
-                    {'description' : 'foo order three', 'items' : (tables.Item, [{'item_name' : 'there'}])}
-                ])
-            }
-        ]
-
+    
+    def test_list_assignment(self):
         sess = create_session()
-        for elem in data[1:]:
-            u = tables.User()
-            sess.save(u)
-            u.user_name = elem['user_name']
-            u.address = tables.Address()
-            u.address.email_address = elem['address'][1]['email_address']
-            u.orders = []
-            for order in elem['orders'][1]:
-                o = tables.Order()
-                o.isopen = None
-                o.description = order['description']
-                u.orders.append(o)
-                o.items = []
-                for item in order['items'][1]:
-                    i = tables.Item()
-                    i.item_name = item['item_name']
-                    o.items.append(i)
-
-        sess.flush()
-        sess.clear()
-
-    def testassignlist(self):
-        sess = create_session()
-        u = tables.User()
-        u.user_name = 'jack'
-        o1 = tables.Order()
-        o1.description ='someorder'
-        o2 = tables.Order()
-        o2.description = 'someotherorder'
-        l = [o1, o2]
+        u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])
         sess.save(u)
-        u.orders = l
-        assert o1 in sess
-        assert o2 in sess
         sess.flush()
         sess.clear()
-
-        u = sess.query(tables.User).get(u.user_id)
-        o3 = tables.Order()
-        o3.description='order3'
-        o4 = tables.Order()
-        o4.description = 'order4'
-        u.orders = [o3, o4]
-        assert o3 in sess
-        assert o4 in sess
+        
+        u = sess.query(User).get(u.id)
+        self.assertEquals(u, User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')]))
+        
+        u.orders=[Order(description="order 3"), Order(description="order 4")]
         sess.flush()
+        sess.clear()
+        
+        u = sess.query(User).get(u.id)
+        self.assertEquals(u, User(name='jack', orders=[Order(description="order 3"), Order(description="order 4")]))
 
-        o5 = tables.Order()
-        o5.description='order5'
+        self.assertEquals(sess.query(Order).all(), [Order(description="order 3"), Order(description="order 4")])
+        o5 = Order(description="order 5")
         sess.save(o5)
         try:
             sess.flush()
@@ -110,108 +48,106 @@ class O2MCascadeTest(TestBase, AssertsExecutionResults):
         except exceptions.FlushError, e:
             assert "is an orphan" in str(e)
 
-
-    def testdelete(self):
+    def test_delete(self):
         sess = create_session()
-        l = sess.query(tables.User).all()
-        for u in l:
-            print repr(u.orders)
-        self.assert_result(l, data[0], *data[1:])
-
-        ids = (l[0].user_id, l[2].user_id)
-        sess.delete(l[0])
-        sess.delete(l[2])
+        u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])
+        sess.save(u)
+        sess.flush()
 
+        sess.delete(u)
         sess.flush()
-        assert tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 0
-        assert tables.orderitems.count(tables.orders.c.user_id.in_(ids)  &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0
-        assert tables.addresses.count(tables.addresses.c.user_id.in_(ids)).scalar() == 0
-        assert tables.users.count(tables.users.c.user_id.in_(ids)).scalar() == 0
+        assert users.count().scalar() == 0
+        assert orders.count().scalar() == 0
 
-    def testdelete2(self):
+    def test_delete_unloaded_collections(self):
         """test that unloaded collections are still included in a delete-cascade by default."""
 
         sess = create_session()
-        u = sess.query(tables.User).filter_by(user_name='ed').one()
-        # assert 'addresses' collection not loaded
+        u = User(name='jack', addresses=[Address(email_address="address1"), Address(email_address="address2")])
+        sess.save(u)
+        sess.flush()
+        sess.clear()
+        assert addresses.count().scalar() == 2
+        assert users.count().scalar() == 1
+        
+        u = sess.query(User).get(u.id)
+        
         assert 'addresses' not in u.__dict__
         sess.delete(u)
         sess.flush()
-        assert tables.addresses.count(tables.addresses.c.email_address=='foo@bar.com').scalar() == 0
-        assert tables.orderitems.count(tables.orderitems.c.item_name.like('eds%')).scalar() == 0
+        assert addresses.count().scalar() == 0
+        assert users.count().scalar() == 0
 
-    def testcascadecollection(self):
+    def test_cascades_onlycollection(self):
         """test that cascade only reaches instances that are still part of the collection,
         not those that have been removed"""
-        sess = create_session()
 
-        u = tables.User()
-        u.user_name = 'newuser'
-        o = tables.Order()
-        o.description = "some description"
-        u.orders.append(o)
+        sess = create_session()
+        u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])
         sess.save(u)
         sess.flush()
-
-        u.orders.remove(o)
+        
+        o = u.orders[0]
+        del u.orders[0]
         sess.delete(u)
         assert u in sess.deleted
         assert o not in sess.deleted
+        assert o in sess
 
-
-    def testorphan(self):
+        u2 = User(name='newuser', orders=[o])
+        sess.save(u2)
+        sess.flush()
+        sess.clear()
+        assert users.count().scalar() == 1
+        assert orders.count().scalar() == 1
+        self.assertEquals(sess.query(User).all(), [User(name='newuser', orders=[Order(description='someorder')])])
+            
+    def test_collection_orphans(self):
         sess = create_session()
-        l = sess.query(tables.User).all()
-        jack = l[1]
-        jack.orders[:] = []
+        u = User(name='jack', orders=[Order(description='someorder'), Order(description='someotherorder')])
+        sess.save(u)
+        sess.flush()
 
-        ids = [jack.user_id]
-        self.assert_(tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 1)
-        self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(ids)  &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 2)
+        assert users.count().scalar() == 1
+        assert orders.count().scalar() == 2
 
-        sess.flush()
+        u.orders[:] = []
 
-        self.assert_(tables.orders.count(tables.orders.c.user_id.in_(ids)).scalar() == 0)
-        self.assert_(tables.orderitems.count(tables.orders.c.user_id.in_(ids)  &(tables.orderitems.c.order_id==tables.orders.c.order_id)).scalar() == 0)
+        sess.flush()
 
+        assert users.count().scalar() == 1
+        assert orders.count().scalar() == 0
 
-class M2OCascadeTest(TestBase, AssertsExecutionResults):
-    def tearDown(self):
-        ctx.current.clear()
-        for t in metadata.table_iterator(reverse=True):
-            t.delete().execute()
 
-    def tearDownAll(self):
-        clear_mappers()
-        metadata.drop_all()
+class M2OCascadeTest(ORMTest):
+    keep_mappers = True
+    
+    def define_tables(self, metadata):
+        global extra, prefs, users
 
-    @testing.uses_deprecated('SessionContext')
-    def setUpAll(self):
-        global ctx, data, metadata, User, Pref, Extra
-        ctx = SessionContext(create_session)
-        metadata = MetaData(testing.db)
         extra = Table("extra", metadata,
-            Column("extra_id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True),
-            Column("prefs_id", Integer, ForeignKey("prefs.prefs_id"))
+            Column("id", Integer, Sequence("extra_id_seq", optional=True), primary_key=True),
+            Column("prefs_id", Integer, ForeignKey("prefs.id"))
         )
         prefs = Table('prefs', metadata,
-            Column('prefs_id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True),
-            Column('prefs_data', String(40)))
+            Column('id', Integer, Sequence('prefs_id_seq', optional=True), primary_key=True),
+            Column('data', String(40)))
 
         users = Table('users', metadata,
-            Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
-            Column('user_name', String(40)),
-            Column('pref_id', Integer, ForeignKey('prefs.prefs_id'))
+            Column('id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
+            Column('name', String(40)),
+            Column('pref_id', Integer, ForeignKey('prefs.id'))
         )
-        class User(object):
-            def __init__(self, name):
-                self.user_name = name
-        class Pref(object):
-            def __init__(self, data):
-                self.prefs_data = data
-        class Extra(object):
+        
+    def setup_mappers(self):
+        global User, Pref, Extra
+        class User(fixtures.Base):
             pass
-        metadata.create_all()
+        class Pref(fixtures.Base):
+            pass
+        class Extra(fixtures.Base):
+            pass
+
         mapper(Extra, extra)
         mapper(Pref, prefs, properties=dict(
             extra = relation(Extra, cascade="all, delete")
@@ -221,64 +157,73 @@ class M2OCascadeTest(TestBase, AssertsExecutionResults):
         ))
 
     def setUp(self):
-        u1 = User("ed")
-        u1.pref = Pref("pref 1")
-        u2 = User("jack")
-        u2.pref = Pref("pref 2")
-        u3 = User("foo")
-        u3.pref = Pref("pref 3")
-        u1.pref.extra.append(Extra())
-        u2.pref.extra.append(Extra())
-        u2.pref.extra.append(Extra())
-
-        ctx.current.save(u1)
-        ctx.current.save(u2)
-        ctx.current.save(u3)
-        ctx.current.flush()
-        ctx.current.clear()
+        u1 = User(name='ed', pref=Pref(data="pref 1", extra=[Extra()]))
+        u2 = User(name='jack', pref=Pref(data="pref 2", extra=[Extra()]))
+        u3 = User(name="foo", pref=Pref(data="pref 3", extra=[Extra()]))
+        sess = create_session()
+        sess.save(u1)
+        sess.save(u2)
+        sess.save(u3)
+        sess.flush()
+        sess.close()
 
     @testing.fails_on('maxdb')
-    def testorphan(self):
-        jack = ctx.current.query(User).filter_by(user_name='jack').one()
-        p = jack.pref
-        e = jack.pref.extra[0]
+    def test_orphan(self):
+        sess = create_session()
+        assert prefs.count().scalar() == 3
+        assert extra.count().scalar() == 3
+        jack = sess.query(User).filter_by(name="jack").one()
         jack.pref = None
-        ctx.current.flush()
-        assert p not in ctx.current
-        assert e not in ctx.current
+        sess.flush()
+        assert prefs.count().scalar() == 2
+        assert extra.count().scalar() == 2
 
     @testing.fails_on('maxdb')
-    def testorphan2(self):
-        jack = ctx.current.query(User).filter_by(user_name='jack').one()
+    def test_orphan_on_update(self):
+        sess = create_session()
+        jack = sess.query(User).filter_by(name="jack").one()
         p = jack.pref
         e = jack.pref.extra[0]
-        ctx.current.clear()
+        sess.clear()
 
         jack.pref = None
-        ctx.current.update(jack)
-        ctx.current.update(p)
-        ctx.current.update(e)
-        assert p in ctx.current
-        assert e in ctx.current
-        ctx.current.flush()
-        assert p not in ctx.current
-        assert e not in ctx.current
-
-    def testorphan3(self):
+        sess.update(jack)
+        sess.update(p)
+        sess.update(e)
+        assert p in sess
+        assert e in sess
+        sess.flush()
+        assert prefs.count().scalar() == 2
+        assert extra.count().scalar() == 2
+    
+    def test_pending_expunge(self):
+        sess = create_session()
+        someuser = User(name='someuser')
+        sess.save(someuser)
+        sess.flush()
+        someuser.pref = p1 = Pref(data='somepref')
+        assert p1 in sess
+        someuser.pref = Pref(data='someotherpref')
+        assert p1 not in sess
+        sess.flush()
+        self.assertEquals(sess.query(Pref).with_parent(someuser).all(), [Pref(data="someotherpref")])
+
+        
+    def test_double_assignment(self):
         """test that double assignment doesn't accidentally reset the 'parent' flag."""
 
-        jack = ctx.current.query(User).filter_by(user_name='jack').one()
-        newpref = Pref("newpref")
+        sess = create_session()
+        jack = sess.query(User).filter_by(name="jack").one()
+
+        newpref = Pref(data="newpref")
         jack.pref = newpref
         jack.pref = newpref
-        ctx.current.flush()
-
-
+        sess.flush()
+        self.assertEquals(sess.query(Pref).all(), [Pref(data="pref 1"), Pref(data="pref 3"), Pref(data="newpref")])
 
-class M2MCascadeTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
-        global metadata, a, b, atob
-        metadata = MetaData(testing.db)
+class M2MCascadeTest(ORMTest):
+    def define_tables(self, metadata):
+        global a, b, atob
         a = Table('a', metadata,
             Column('id', Integer, primary_key=True),
             Column('data', String(30))
@@ -292,18 +237,12 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults):
             Column('bid', Integer, ForeignKey('b.id'))
 
             )
-        metadata.create_all()
-
-    def tearDownAll(self):
-        metadata.drop_all()
 
-    def testdeleteorphan(self):
-        class A(object):
-            def __init__(self, data):
-                self.data = data
-        class B(object):
-            def __init__(self, data):
-                self.data = data
+    def test_delete_orphan(self):
+        class A(fixtures.Base):
+            pass
+        class B(fixtures.Base):
+            pass
 
         mapper(A, a, properties={
             # if no backref here, delete-orphan failed until [ticket:427] was fixed
@@ -312,9 +251,8 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults):
         mapper(B, b)
 
         sess = create_session()
-        a1 = A('a1')
-        b1 = B('b1')
-        a1.bs.append(b1)
+        b1 = B(data='b1')
+        a1 = A(data='a1', bs=[b1])
         sess.save(a1)
         sess.flush()
 
@@ -324,13 +262,11 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults):
         assert b.count().scalar() == 0
         assert a.count().scalar() == 1
 
-    def testcascadedelete(self):
-        class A(object):
-            def __init__(self, data):
-                self.data = data
-        class B(object):
-            def __init__(self, data):
-                self.data = data
+    def test_cascade_delete(self):
+        class A(fixtures.Base):
+            pass
+        class B(fixtures.Base):
+            pass
 
         mapper(A, a, properties={
             'bs':relation(B, secondary=atob, cascade="all, delete-orphan")
@@ -338,9 +274,7 @@ class M2MCascadeTest(TestBase, AssertsExecutionResults):
         mapper(B, b)
 
         sess = create_session()
-        a1 = A('a1')
-        b1 = B('b1')
-        a1.bs.append(b1)
+        a1 = A(data='a1', bs=[B(data='b1')])
         sess.save(a1)
         sess.flush()
 
@@ -357,7 +291,7 @@ class UnsavedOrphansTest(ORMTest):
         global users, addresses, User, Address
         users = Table('users', metadata,
             Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
-            Column('user_name', String(40)),
+            Column('name', String(40)),
         )
 
         addresses = Table('email_addresses', metadata,
@@ -365,10 +299,10 @@ class UnsavedOrphansTest(ORMTest):
             Column('user_id', Integer, ForeignKey(users.c.user_id)),
             Column('email_address', String(40)),
         )
-        class User(object):pass
-        class Address(object):pass
+        class User(fixtures.Base):pass
+        class Address(fixtures.Base):pass
 
-    def test_pending_orphan(self):
+    def test_pending_standalone_orphan(self):
         """test that an entity that never had a parent on a delete-orphan cascade cant be saved."""
 
         mapper(Address, addresses)
@@ -384,9 +318,8 @@ class UnsavedOrphansTest(ORMTest):
             pass
         assert a.address_id is None, "Error: address should not be persistent"
 
-    def test_delete_new_object(self):
-        """test that an entity which is attached then detached from its
-        parent with a delete-orphan cascade gets counted as an orphan"""
+    def test_pending_collection_expunge(self):
+        """test that removing a pending item from a collection expunges it from the session."""
 
         mapper(Address, addresses)
         mapper(User, users, properties=dict(
@@ -398,18 +331,33 @@ class UnsavedOrphansTest(ORMTest):
         s.save(u)
         s.flush()
         a = Address()
-        assert a not in s.new
+
         u.addresses.append(a)
+        assert a in s
+        
         u.addresses.remove(a)
-        s.delete(u)
-        try:
-            s.flush() # (erroneously) causes "a" to be persisted
-            assert False
-        except exceptions.FlushError:
-            assert True
-        assert a.address_id is None, "Error: address should not be persistent"
+        assert a not in s
 
+        s.delete(u)
+        s.flush() 
 
+        assert a.address_id is None, "Error: address should not be persistent"
+    
+    def test_nonorphans_ok(self):
+        mapper(Address, addresses)
+        mapper(User, users, properties=dict(
+            addresses=relation(Address, cascade="all,delete", backref="user")
+        ))
+        s = create_session()
+        u = User(name='u1', addresses=[Address(email_address='ad1')])
+        s.save(u)
+        a1 = u.addresses[0]
+        u.addresses.remove(a1)
+        assert a1 in s
+        s.flush()
+        s.clear()
+        self.assertEquals(s.query(Address).all(), [Address(email_address='ad1')])
+        
 class UnsavedOrphansTest2(ORMTest):
     """same test as UnsavedOrphans only three levels deep"""
 
@@ -433,56 +381,52 @@ class UnsavedOrphansTest2(ORMTest):
 
         )
 
-    def testdeletechildwithchild(self):
-        """test that an entity which is attached then detached from its
-        parent with a delete-orphan cascade gets counted as an orphan, as well
-        as its own child instances"""
-
-        class Order(object): pass
-        class Item(object): pass
-        class Attribute(object): pass
+    def test_pending_expunge(self):
+        class Order(fixtures.Base):
+            pass
+        class Item(fixtures.Base):
+            pass
+        class Attribute(fixtures.Base):
+            pass
 
-        attrMapper = mapper(Attribute, attributes)
-        itemMapper = mapper(Item, items, properties=dict(
-            attributes=relation(attrMapper, cascade="all,delete-orphan", backref="item")
+        mapper(Attribute, attributes)
+        mapper(Item, items, properties=dict(
+            attributes=relation(Attribute, cascade="all,delete-orphan", backref="item")
         ))
-        orderMapper = mapper(Order, orders, properties=dict(
-            items=relation(itemMapper, cascade="all,delete-orphan", backref="order")
+        mapper(Order, orders, properties=dict(
+            items=relation(Item, cascade="all,delete-orphan", backref="order")
         ))
 
-        s = create_session( )
-        order = Order()
+        s = create_session()
+        order = Order(name="order1")
         s.save(order)
 
-        item = Item()
-        attr = Attribute()
-        item.attributes.append(attr)
+        attr = Attribute(name="attr1")
+        item = Item(name="item1", attributes=[attr])
 
         order.items.append(item)
-        order.items.remove(item) # item is an orphan, but attr is not so flush() tries to save attr
-        try:
-            s.flush()
-            assert False
-        except exceptions.FlushError, e:
-            print e
-            assert True
-
-        assert item.id is None
-        assert attr.id is None
+        order.items.remove(item) 
+        
+        assert item not in s
+        assert attr not in s
+        
+        s.flush()
+        assert orders.count().scalar() == 1
+        assert items.count().scalar() == 0
+        assert attributes.count().scalar() == 0
 
-class DoubleParentOrphanTest(TestBase, AssertsExecutionResults):
+class DoubleParentOrphanTest(ORMTest):
     """test orphan detection for an entity with two parent relations"""
 
-    def setUpAll(self):
-        global metadata, address_table, businesses, homes
-        metadata = MetaData(testing.db)
+    def define_tables(self, metadata):
+        global address_table, businesses, homes
         address_table = Table('addresses', metadata,
             Column('address_id', Integer, primary_key=True),
             Column('street', String(30)),
         )
 
         homes = Table('homes', metadata,
-            Column('home_id', Integer, primary_key=True),
+            Column('home_id', Integer, primary_key=True, key="id"),
             Column('description', String(30)),
             Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False),
         )
@@ -492,38 +436,42 @@ class DoubleParentOrphanTest(TestBase, AssertsExecutionResults):
             Column('description', String(30), key="description"),
             Column('address_id', Integer, ForeignKey('addresses.address_id'), nullable=False),
         )
-        metadata.create_all()
-    def tearDown(self):
-        clear_mappers()
-    def tearDownAll(self):
-        metadata.drop_all()
+        
     def test_non_orphan(self):
         """test that an entity can have two parent delete-orphan cascades, and persists normally."""
 
-        class Address(object):pass
-        class Home(object):pass
-        class Business(object):pass
+        class Address(fixtures.Base):
+            pass
+        class Home(fixtures.Base):
+            pass
+        class Business(fixtures.Base):
+            pass
+        
         mapper(Address, address_table)
         mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")})
         mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")})
 
         session = create_session()
-        a1 = Address()
-        a2 = Address()
-        h1 = Home()
-        b1 = Business()
-        h1.address = a1
-        b1.address = a2
+        h1 = Home(description='home1', address=Address(street='address1'))
+        b1 = Business(description='business1', address=Address(street='address2'))
         [session.save(x) for x in [h1,b1]]
         session.flush()
+        session.clear()
+        
+        self.assertEquals(session.query(Home).get(h1.id), Home(description='home1', address=Address(street='address1')))
+        self.assertEquals(session.query(Business).get(b1.id), Business(description='business1', address=Address(street='address2')))
 
     def test_orphan(self):
         """test that an entity can have two parent delete-orphan cascades, and is detected as an orphan
         when saved without a parent."""
 
-        class Address(object):pass
-        class Home(object):pass
-        class Business(object):pass
+        class Address(fixtures.Base):
+            pass
+        class Home(fixtures.Base):
+            pass
+        class Business(fixtures.Base):
+            pass
+        
         mapper(Address, address_table)
         mapper(Home, homes, properties={'address':relation(Address, cascade="all,delete-orphan")})
         mapper(Business, businesses, properties={'address':relation(Address, cascade="all,delete-orphan")})
@@ -537,60 +485,47 @@ class DoubleParentOrphanTest(TestBase, AssertsExecutionResults):
         except exceptions.FlushError, e:
             assert True
 
-class CollectionAssignmentOrphanTest(TestBase, AssertsExecutionResults):
-    def setUpAll(self):
-        global metadata, table_a, table_b
+class CollectionAssignmentOrphanTest(ORMTest):
+    def define_tables(self, metadata):
+        global table_a, table_b
 
-        metadata = MetaData(testing.db)
         table_a = Table('a', metadata,
                         Column('id', Integer, primary_key=True),
-                        Column('foo', String(30)))
+                        Column('name', String(30)))
         table_b = Table('b', metadata,
                         Column('id', Integer, primary_key=True),
-                        Column('foo', String(30)),
+                        Column('name', String(30)),
                         Column('a_id', Integer, ForeignKey('a.id')))
-        metadata.create_all()
-
-    def tearDown(self):
-        clear_mappers()
-    def tearDownAll(self):
-        metadata.drop_all()
 
     def test_basic(self):
-        class A(object):
-            def __init__(self, foo):
-                self.foo = foo
-        class B(object):
-            def __init__(self, foo):
-                self.foo = foo
+        class A(fixtures.Base):
+            pass
+        class B(fixtures.Base):
+            pass
 
         mapper(A, table_a, properties={
             'bs':relation(B, cascade="all, delete-orphan")
             })
         mapper(B, table_b)
 
-        a1 = A('a1')
-        a1.bs.append(B('b1'))
-        a1.bs.append(B('b2'))
-        a1.bs.append(B('b3'))
+        a1 = A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')])
 
         sess = create_session()
         sess.save(a1)
         sess.flush()
 
-        assert table_b.count(table_b.c.a_id == None).scalar() == 0
-
-        assert table_b.count().scalar() == 3
+        sess.clear()
+        
+        self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
 
         a1 = sess.query(A).get(a1.id)
-        assert len(a1.bs) == 3
-        a1.bs = list(a1.bs)
         assert not class_mapper(B)._is_orphan(a1.bs[0])
         a1.bs[0].foo='b2modified'
         a1.bs[1].foo='b3modified'
         sess.flush()
 
-        assert table_b.count().scalar() == 3
+        sess.clear()
+        self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
 
 if __name__ == "__main__":
     testenv.main()
index a2ac6cf4e09b3ead5f3480a70278761cdfdfa600..d48abd5d2ef7e9767b49fda4908468ecb175055b 100644 (file)
@@ -13,10 +13,6 @@ class QueryTest(FixtureTest):
     keep_mappers = True
     keep_data = True
 
-    def setUpAll(self):
-        super(QueryTest, self).setUpAll()
-        self.setup_mappers()
-
     def setup_mappers(self):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user'),
index 8a30287996330c90f4cdafed62688fb66f47e472..1d7062c882eb62220044a4c4a1b2c3c81da644cb 100644 (file)
@@ -648,11 +648,15 @@ class ORMTest(TestBase, AssertsExecutionResults):
                 _otest_metadata.bind = config.db
         self.define_tables(_otest_metadata)
         _otest_metadata.create_all()
+        self.setup_mappers()
         self.insert_data()
 
     def define_tables(self, _otest_metadata):
         raise NotImplementedError()
-
+    
+    def setup_mappers(self):
+        pass
+        
     def insert_data(self):
         pass