]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- The versioning example now supports detection of changes
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 15:44:31 +0000 (10:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 15:44:31 +0000 (10:44 -0500)
in an associated relationship().

CHANGES
examples/versioning/history_meta.py
examples/versioning/test_versioning.py

diff --git a/CHANGES b/CHANGES
index 407d40d226888f933cc2f630363e3ee9bfc64a09..c360f1606f8d0ddc856950e8bdd85d226995f544 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -156,6 +156,10 @@ CHANGES
     internal "cache" dictionary.  Particularly since the
     join() and select() objects are created in the method
     itself this was pretty much a pure memory leaking behavior.
+
+- examples
+  - The versioning example now supports detection of changes
+    in an associated relationship().
     
 0.6.5
 =====
index 0a631e8492ed7700dba47af6bc1b91a60286149a..fa95733e29c548a5bcb0f04d4bd98bab376c826d 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy.orm import mapper, class_mapper, attributes, object_mapper
 from sqlalchemy.orm.exc import UnmappedClassError, UnmappedColumnError
 from sqlalchemy import Table, Column, ForeignKeyConstraint, Integer
 from sqlalchemy.orm.interfaces import SessionExtension
+from sqlalchemy.orm.properties import RelationshipProperty
 
 def col_references_table(col, table):
     for fk in col.foreign_keys:
@@ -147,7 +148,16 @@ def create_version(obj, session, deleted = False):
                 # if the attribute had no value.
                 attr[hist_col.key] = a[0]
                 obj_changed = True
-                
+    
+    if not obj_changed:
+        # not changed, but we have relationships.  OK
+        # check those too
+        for prop in obj_mapper.iterate_properties:
+            if isinstance(prop, RelationshipProperty) and \
+                attributes.get_history(obj, prop.key).has_changes():
+                obj_changed = True
+                break
+        
     if not obj_changed and not deleted:            
         return
 
index 031d7ca261bc834698c9bc8c1422cbcf122c4653..47e556b0a046509de5da4ecc5d4ca576c9ed5445 100644 (file)
@@ -1,7 +1,8 @@
 from sqlalchemy.ext.declarative import declarative_base
 from history_meta import VersionedMeta, VersionedListener
 from sqlalchemy import create_engine, Column, Integer, String, ForeignKey
-from sqlalchemy.orm import clear_mappers, compile_mappers, sessionmaker, deferred
+from sqlalchemy.orm import clear_mappers, compile_mappers, \
+    sessionmaker, deferred, relationship
 from sqlalchemy.test.testing import TestBase, eq_
 from sqlalchemy.test.entities import ComparableEntity
 
@@ -11,8 +12,11 @@ def setup():
     
 class TestVersioning(TestBase):
     def setup(self):
-        global Base, Session
-        Base = declarative_base(metaclass=VersionedMeta, bind=engine)
+        global Base, Session, Versioned
+        Base = declarative_base(bind=engine)
+        class Versioned(object):
+            __metaclass__ = VersionedMeta
+            _decl_class_registry = Base._decl_class_registry
         Session = sessionmaker(extension=VersionedListener())
         
     def teardown(self):
@@ -23,7 +27,7 @@ class TestVersioning(TestBase):
         Base.metadata.create_all()
         
     def test_plain(self):
-        class SomeClass(Base, ComparableEntity):
+        class SomeClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'sometable'
             
             id = Column(Integer, primary_key=True)
@@ -87,7 +91,7 @@ class TestVersioning(TestBase):
         )
 
     def test_from_null(self):
-        class SomeClass(Base, ComparableEntity):
+        class SomeClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'sometable'
             
             id = Column(Integer, primary_key=True)
@@ -107,7 +111,7 @@ class TestVersioning(TestBase):
     def test_deferred(self):
         """test versioning of unloaded, deferred columns."""
         
-        class SomeClass(Base, ComparableEntity):
+        class SomeClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'sometable'
 
             id = Column(Integer, primary_key=True)
@@ -138,7 +142,7 @@ class TestVersioning(TestBase):
         
         
     def test_joined_inheritance(self):
-        class BaseClass(Base, ComparableEntity):
+        class BaseClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'basetable'
 
             id = Column(Integer, primary_key=True)
@@ -215,7 +219,7 @@ class TestVersioning(TestBase):
         )
 
     def test_single_inheritance(self):
-        class BaseClass(Base, ComparableEntity):
+        class BaseClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'basetable'
 
             id = Column(Integer, primary_key=True)
@@ -261,7 +265,7 @@ class TestVersioning(TestBase):
         )
     
     def test_unique(self):
-        class SomeClass(Base, ComparableEntity):
+        class SomeClass(Versioned, Base, ComparableEntity):
             __tablename__ = 'sometable'
             
             id = Column(Integer, primary_key=True)
@@ -284,3 +288,51 @@ class TestVersioning(TestBase):
         
         assert sc.version == 3
 
+    def test_relationship(self):
+
+        class SomeRelated(Base, ComparableEntity):
+            __tablename__ = 'somerelated'
+            
+            id = Column(Integer, primary_key=True)
+
+        class SomeClass(Versioned, Base, ComparableEntity):
+            __tablename__ = 'sometable'
+            
+            id = Column(Integer, primary_key=True)
+            name = Column(String(50))
+            related_id = Column(Integer, ForeignKey('somerelated.id'))
+            related = relationship("SomeRelated")
+            
+        SomeClassHistory = SomeClass.__history_mapper__.class_
+            
+        self.create_tables()
+        sess = Session()
+        sc = SomeClass(name='sc1')
+        sess.add(sc)
+        sess.commit()
+
+        assert sc.version == 1
+        
+        sr1 = SomeRelated()
+        sc.related = sr1
+        sess.commit()
+        
+        assert sc.version == 2
+        
+        eq_(
+            sess.query(SomeClassHistory).filter(SomeClassHistory.version == 1).all(),
+            [SomeClassHistory(version=1, name='sc1', related_id=None)]
+        )
+
+        sc.related = None
+
+        eq_(
+            sess.query(SomeClassHistory).order_by(SomeClassHistory.version).all(),
+            [
+                SomeClassHistory(version=1, name='sc1', related_id=None),
+                SomeClassHistory(version=2, name='sc1', related_id=sr1.id)
+            ]
+        )
+
+        assert sc.version == 3
+