]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- modernize MutableTypesTest, add tests for expired/deferred to establish 0.6 behavior
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Nov 2010 23:49:06 +0000 (18:49 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Nov 2010 23:49:06 +0000 (18:49 -0500)
regarding [ticket:1976]

test/orm/test_relationships.py
test/orm/test_unitofwork.py

index 555389a09126e80e87df6dea37d084695f9ba971..5033a84fe6d34d5e55d0a859e8be61fb1ab52a8b 100644 (file)
@@ -2377,7 +2377,8 @@ class ActiveHistoryFlagTest(_fixtures.FixtureTest):
     @testing.resolve_artifact_names
     def test_column_property_flag(self):
         mapper(User, users, properties={
-            'name':column_property(users.c.name, active_history=True)
+            'name':column_property(users.c.name, 
+                                active_history=True)
         })
         u1 = User(name='jack')
         self._test_attribute(u1, 'name', 'ed')
@@ -2395,8 +2396,6 @@ class ActiveHistoryFlagTest(_fixtures.FixtureTest):
     
     @testing.resolve_artifact_names
     def test_composite_property_flag(self):
-        # active_history is implicit for composites
-        # right now, no flag needed
         class MyComposite(object):
             def __init__(self, description, isopen):
                 self.description = description
index 52a93a122c8349b598ace6debe1bfc114e0d6295..312195bf618e2b44e58b7e96762b69735fa2af87 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.test.schema import Column
 from sqlalchemy.orm import mapper, relationship, create_session, \
     column_property, attributes, Session, reconstructor, object_session
 from sqlalchemy.test.testing import eq_, ne_
+from sqlalchemy.test.util import gc_collect
 from test.orm import _base, _fixtures
 from test.engine import _base as engine_base
 from sqlalchemy.test.assertsql import AllOf, CompiledSQL
@@ -272,73 +273,76 @@ class MutableTypesTest(_base.MappedTest):
         mapper(Foo, mutable_t)
 
     @testing.resolve_artifact_names
-    def test_basic(self):
-        """Changes are detected for types marked as MutableType."""
-
-        f1 = Foo()
-        f1.data = pickleable.Bar(4,5)
-
-        session = create_session()
+    def test_modified_status(self):
+        f1 = Foo(data = pickleable.Bar(4,5))
+        
+        session = Session()
         session.add(f1)
-        session.flush()
-        session.expunge_all()
+        session.commit()
 
-        f2 = session.query(Foo).filter_by(id=f1.id).one()
+        f2 = session.query(Foo).first()
         assert 'data' in sa.orm.attributes.instance_state(f2).unmodified
         eq_(f2.data, f1.data)
 
         f2.data.y = 19
         assert f2 in session.dirty
         assert 'data' not in sa.orm.attributes.instance_state(f2).unmodified
-        session.flush()
-        session.expunge_all()
-
-        f3 = session.query(Foo).filter_by(id=f1.id).one()
+    
+    @testing.resolve_artifact_names
+    def test_mutations_persisted(self):
+        f1 = Foo(data = pickleable.Bar(4,5))
+        
+        session = Session()
+        session.add(f1)
+        session.commit()
+        f1.data
+        session.close()
+        
+        f2 = session.query(Foo).first()
+        f2.data.y = 19
+        session.commit()
+        f2.data
+        session.close()
+        
+        f3 = session.query(Foo).first()
         ne_(f3.data,f1.data)
         eq_(f3.data, pickleable.Bar(4, 19))
-
+        
     @testing.resolve_artifact_names
-    def test_mutable_changes(self):
-        """Mutable changes are detected or not detected correctly"""
-
-        f1 = Foo()
-        f1.data = pickleable.Bar(4,5)
-        f1.val = u'hi'
+    def test_no_unnecessary_update(self):
+        f1 = Foo(data = pickleable.Bar(4,5), val = u'hi')
 
-        session = create_session(autocommit=False)
+        session = Session()
         session.add(f1)
         session.commit()
 
-        bind = self.metadata.bind
-
         self.sql_count_(0, session.commit)
+        
         f1.val = u'someothervalue'
-        self.assert_sql(bind, session.commit, [
+        self.assert_sql(testing.db, session.commit, [
             ("UPDATE mutable_t SET val=:val "
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'someothervalue'})])
 
         f1.val = u'hi'
         f1.data.x = 9
-        self.assert_sql(bind, session.commit, [
+        self.assert_sql(testing.db, session.commit, [
             ("UPDATE mutable_t SET data=:data, val=:val "
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
-
+        
     @testing.resolve_artifact_names
-    def test_resurrect(self):
-        f1 = Foo()
-        f1.data = pickleable.Bar(4,5)
-        f1.val = u'hi'
+    def test_mutated_state_resurrected(self):
+        f1 = Foo(data = pickleable.Bar(4,5), val = u'hi')
 
-        session = create_session(autocommit=False)
+        session = Session()
         session.add(f1)
         session.commit()
 
         f1.data.y = 19
         del f1
 
-        gc.collect()
+        gc_collect()
         assert len(session.identity_map) == 1
 
         session.commit()
@@ -346,22 +350,12 @@ class MutableTypesTest(_base.MappedTest):
         assert session.query(Foo).one().data == pickleable.Bar(4, 19)
 
     @testing.resolve_artifact_names
-    def test_resurrect_two(self):
-        f1 = Foo()
-        f1.data = pickleable.Bar(4,5)
-        session = create_session(autocommit=False)
-        session.add(f1)
-        session.commit()
+    def test_mutated_plus_scalar_state_change_resurrected(self):
+        """test that a non-mutable attribute event subsequent to
+        a mutable event prevents the object from falling into
+        resurrected state.
         
-        session = create_session(autocommit=False)
-        f1 = session.query(Foo).first()
-        del f1 # modified flag flips by accident
-        gc.collect()
-        f1 = session.query(Foo).first()
-        assert not attributes.instance_state(f1).modified
-
-    @testing.resolve_artifact_names
-    def test_modified_after_mutable_change(self):
+         """
         f1 = Foo(data = pickleable.Bar(4, 5), val=u'some val')
         session = Session()
         session.add(f1)
@@ -378,22 +372,133 @@ class MutableTypesTest(_base.MappedTest):
             session.query(Foo.val).all(),
             [('some new val', )]
         )
+
+    @testing.resolve_artifact_names
+    def test_non_mutated_state_not_resurrected(self):
+        f1 = Foo(data = pickleable.Bar(4,5))
+        
+        session = Session()
+        session.add(f1)
+        session.commit()
         
+        session = Session()
+        f1 = session.query(Foo).first()
+        del f1
+        gc_collect()
+
+        assert len(session.identity_map) == 0
+        f1 = session.query(Foo).first()
+        assert not attributes.instance_state(f1).modified
+
     @testing.resolve_artifact_names
-    def test_unicode(self):
-        """Equivalent Unicode values are not flagged as changed."""
+    def test_scalar_no_net_change_no_update(self):
+        """Test that a no-net-change on a scalar attribute event
+        doesn't cause an UPDATE for a mutable state.
+        
+         """
 
         f1 = Foo(val=u'hi')
 
-        session = create_session(autocommit=False)
+        session = Session()
         session.add(f1)
         session.commit()
-        session.expunge_all()
+        session.close()
 
-        f1 = session.query(Foo).get(f1.id)
+        f1 = session.query(Foo).first()
         f1.val = u'hi'
         self.sql_count_(0, session.commit)
 
+    @testing.resolve_artifact_names
+    def test_expire_attribute_set(self):
+        """test one SELECT emitted when assigning to an expired
+        mutable attribute - this will become 0 in 0.7.
+        
+        """
+        
+        f1 = Foo(data = pickleable.Bar(4, 5), val=u'some val')
+        session = Session()
+        session.add(f1)
+        session.commit()
+        
+        assert 'data' not in f1.__dict__
+        def go():
+            f1.data = pickleable.Bar(10, 15)
+        self.sql_count_(1, go)
+        session.commit()
+        
+        eq_(f1.data.x, 10)
+
+    @testing.resolve_artifact_names
+    def test_expire_mutate(self):
+        """test mutations are detected on an expired mutable
+        attribute."""
+        
+        f1 = Foo(data = pickleable.Bar(4, 5), val=u'some val')
+        session = Session()
+        session.add(f1)
+        session.commit()
+        
+        assert 'data' not in f1.__dict__
+        def go():
+            f1.data.x = 10
+        self.sql_count_(1, go)
+        session.commit()
+        
+        eq_(f1.data.x, 10)
+        
+    @testing.resolve_artifact_names
+    def test_deferred_attribute_set(self):
+        """test one SELECT emitted when assigning to a deferred
+        mutable attribute - this will become 0 in 0.7.
+        
+        """
+        sa.orm.clear_mappers()
+        mapper(Foo, mutable_t, properties={
+            'data':sa.orm.deferred(mutable_t.c.data)
+        })
+
+        f1 = Foo(data = pickleable.Bar(4, 5), val=u'some val')
+        session = Session()
+        session.add(f1)
+        session.commit()
+        
+        session.close()
+        
+        f1 = session.query(Foo).first()
+        def go():
+            f1.data = pickleable.Bar(10, 15)
+        self.sql_count_(1, go)
+        session.commit()
+        
+        eq_(f1.data.x, 10)
+
+    @testing.resolve_artifact_names
+    def test_deferred_mutate(self):
+        """test mutations are detected on a deferred mutable
+        attribute."""
+        
+        sa.orm.clear_mappers()
+        mapper(Foo, mutable_t, properties={
+            'data':sa.orm.deferred(mutable_t.c.data)
+        })
+
+        f1 = Foo(data = pickleable.Bar(4, 5), val=u'some val')
+        session = Session()
+        session.add(f1)
+        session.commit()
+        
+        session.close()
+        
+        f1 = session.query(Foo).first()
+        def go():
+            f1.data.x = 10
+        self.sql_count_(1, go)
+        session.commit()
+        
+        def go():
+            eq_(f1.data.x, 10)
+        self.sql_count_(1, go)
+        
 
 class PickledDictsTest(_base.MappedTest):