]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added "pickleable" module to test suite to have cPickle-compatible
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Sep 2006 20:26:20 +0000 (20:26 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Sep 2006 20:26:20 +0000 (20:26 +0000)
test objects
- added copy_function, compare_function arguments to InstrumentedAttribute
- added MutableType mixin, copy_value/compare_values methods to TypeEngine,
PickleType
- ColumnProperty and DeferredProperty propigate the TypeEngine copy/compare
methods to the attribute instrumentation
- cleanup of UnitOfWork, removed unused methods
- UnitOfWork "dirty" list is calculated across the total collection of persistent
objects when called, no longer has register_dirty.
- attribute system can still report "modified" status fairly quickly, but does
extra work for InstrumentedAttributes that have detected a "mutable" type where
catching the __set__() event is not enough (i.e. PickleTypes)
- attribute tracking modified to be more intelligent about detecting
changes, particularly with mutable types.  TypeEngine objects now
take a greater role in defining how to compare two scalar instances,
including the addition of a MutableType mixin which is implemented by
PickleType.  unit-of-work now tracks the "dirty" list as an expression
of all persistent objects where the attribute manager detects changes.
The basic issue thats fixed is detecting changes on PickleType
objects, but also generalizes type handling and "modified" object
checking to be more complete and extensible.

14 files changed:
CHANGES
doc/build/content/unitofwork.txt
lib/sqlalchemy/attributes.py
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/types.py
test/base/attributes.py
test/orm/cycles.py
test/orm/unitofwork.py
test/pickleable.py [new file with mode: 0644]
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index 438b1993f609a07e55fa21784e438661513fe65f..e7954d9830642aadbbcd70d875f8431ec7179c70 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -40,6 +40,15 @@ explicitly set to False
 - connection pool tracks open cursors and raises an error if connection
 is returned to pool with cursors still opened.  fixes issues with MySQL, 
 others
+- attribute tracking modified to be more intelligent about detecting
+changes, particularly with mutable types.  TypeEngine objects now
+take a greater role in defining how to compare two scalar instances,
+including the addition of a MutableType mixin which is implemented by
+PickleType.  unit-of-work now tracks the "dirty" list as an expression
+of all persistent objects where the attribute manager detects changes.
+The basic issue thats fixed is detecting changes on PickleType 
+objects, but also generalizes type handling and "modified" object
+checking to be more complete and extensible.
 
 0.2.8
 - cleanup on connection methods + documentation.  custom DBAPI
index 8f49ad896a6132aa45f77bcf1130693b4758a248..3203bf10d5ba89ff96f3b0a36723bc3bfca602b2 100644 (file)
@@ -125,15 +125,16 @@ The `get()` method on `Query`, which retrieves an object based on primary key id
 
 ### Whats Changed ? {@name=changed}    
 
-The next concept is that in addition to the `Session` storing a record of all objects loaded or saved, it also stores lists of all *newly created* (i.e. pending) objects,  lists of all persistent objects whose attributes have been *modified*, and lists of all persistent objects that have been marked as *deleted*.  These lists are used when a `flush()` call is issued to save all changes.  After the flush occurs, these lists are all cleared out.
+The next concept is that in addition to the `Session` storing a record of all objects loaded or saved, it also stores lists of all *newly created* (i.e. pending) objects and lists of all persistent objects that have been marked as *deleted*.  These lists are used when a `flush()` call is issued to save all changes.  During a flush operation, it also scans its list of persistent instances for changes which are marked as dirty.
     
 These records are all tracked by a collection of `Set` objects (which are a SQLAlchemy-specific instance called a `HashSet`) that are also viewable off the `Session`:
 
     {python}
     # pending objects recently added to the Session
     session.new
-
-    # persistent objects with modifications
+    
+    # persistent objects which currently have changes detected
+    # (this Set is now created on the fly each time the property is called)
     session.dirty
 
     # persistent objects that have been marked as deleted via session.delete(obj)
index 84a1d58fb8741618cd63f6e7b4c3db5ecb722365..7fd9686e3568a181585974d4bdfefa3ff3146a9b 100644 (file)
@@ -13,13 +13,28 @@ class InstrumentedAttribute(object):
     
     PASSIVE_NORESULT = object()
     
-    def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, **kwargs):
+    def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         self.manager = manager
         self.key = key
         self.uselist = uselist
         self.callable_ = callable_
         self.typecallable= typecallable
         self.trackparent = trackparent
+        if copy_function is None:
+            self._check_mutable_modified = False
+            if uselist:
+                self._copyfunc = lambda x: [y for y in x]
+            else:
+                # scalar values are assumed to be immutable unless a copy function
+                # is passed
+                self._copyfunc = lambda x: x
+        else:
+            self._check_mutable_modified = True
+            self._copyfunc = copy_function
+        if compare_function is None:
+            self._compare_function = lambda x,y: x == y
+        else:
+            self._compare_function = compare_function
         self.extensions = util.to_list(extension or [])
 
     def __set__(self, obj, value):
@@ -31,6 +46,23 @@ class InstrumentedAttribute(object):
             return self
         return self.get(obj)
 
+    def is_equal(self, x, y):
+        return self._compare_function(x, y)
+    def copy(self, value):
+        return self._copyfunc(value)
+    
+    def check_mutable_modified(self, obj):
+        if self._check_mutable_modified:
+            h = self.get_history(obj, passive=True)
+            if h is not None and h.is_modified():
+                obj._state['modified'] = True
+                return True
+            else:
+                return False
+        else:
+            return False
+            
+        
     def hasparent(self, item, optimistic=False):
         """return the boolean value of a "hasparent" flag attached to the given item.
         
@@ -490,16 +522,14 @@ class CommittedState(object):
             if obj.__dict__.has_key(attr.key):
                 value = obj.__dict__[attr.key]
         if value is not False:
-            if attr.uselist:
-                self.data[attr.key] = [x for x in value]
-                # not tracking parent on lazy-loaded instances at the moment.
-                # its not needed since they will be "optimistically" tested
+            self.data[attr.key] = attr.copy(value)
+
+            # not tracking parent on lazy-loaded instances at the moment.
+            # its not needed since they will be "optimistically" tested
+            #if attr.uselist:
                 #if attr.trackparent:
                 #    [attr.sethasparent(x, True) for x in self.data[attr.key] if x is not None]
-            else:
-                self.data[attr.key] = value
-                # not tracking parent on lazy-loaded instances at the moment.
-                # its not needed since they will be "optimistically" tested
+            #else:
                 #if attr.trackparent and value is not None:
                 #    attr.sethasparent(value, True)
 
@@ -550,7 +580,7 @@ class AttributeHistory(object):
                 if a not in self._unchanged_items:
                     self._deleted_items.append(a)    
         else:
-            if current is original:
+            if attr.is_equal(current, original):
                 self._unchanged_items = [current]
                 self._added_items = []
                 self._deleted_items = []
@@ -564,6 +594,8 @@ class AttributeHistory(object):
         #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items
     def __iter__(self):
         return iter(self._current)
+    def is_modified(self):
+        return len(self._deleted_items) > 0 or len(self._added_items) > 0
     def added_items(self):
         return self._added_items
     def unchanged_items(self):
@@ -622,6 +654,9 @@ class AttributeManager(object):
                 yield value
                 
     def is_modified(self, object):
+        for attr in self.managed_attributes(object.__class__):
+            if attr.check_mutable_modified(object):
+                return True
         return object._state.get('modified', False)
         
     def init_attr(self, obj):
index 276510b18c2ec60dd5240caa839a2c5fc77aa048..0442a58f03e9da245f72ea3b6acd29a10826b9c0 100644 (file)
@@ -18,12 +18,13 @@ from session import Session as create_session
 
 __all__ = ['relation', 'backref', 'eagerload', 'lazyload', 'noload', 'deferred', 'defer', 'undefer',
         'mapper', 'clear_mappers', 'sql', 'extension', 'class_mapper', 'object_mapper', 'MapperExtension', 'Query', 
-        'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'EXT_PASS' 
+        'cascade_mappers', 'polymorphic_union', 'create_session', 'synonym', 'EXT_PASS'
         ]
 
 def relation(*args, **kwargs):
-    """provides a relationship of a primary Mapper to a secondary Mapper, which corresponds
-    to a parent-child or associative table relationship."""
+    """provide a relationship of a primary Mapper to a secondary Mapper.
+    
+    This corresponds to a parent-child or associative table relationship."""
     if len(args) > 1 and isinstance(args[0], type):
         raise exceptions.ArgumentError("relation(class, table, **kwargs) is deprecated.  Please use relation(class, **kwargs) or relation(mapper, **kwargs).")
     return _relation_loader(*args, **kwargs)
index 0f298a1fbe4c8b6e51d7063b719f79a1dd76a594..d2fadc2573bb8210d95de628b51997511d1bbc75 100644 (file)
@@ -460,7 +460,7 @@ class Mapper(object):
             # save()d to some session.
             if session is not None and mapper is not None:
                 self._entity_name = entity_name
-                session._register_new(self)
+                session._register_pending(self)
                 
             if oldinit is not None:
                 try:
@@ -637,7 +637,7 @@ class Mapper(object):
         for value in imap.values():
             if value is scratch:
                 continue
-            session._register_clean(value)
+            session._register_persistent(value)
             
         if mappers:
             return [result.data] + [o.data for o in otherresults]
index 5015cb3236deb78620651b932dc73f38c7f25808..c78593fcbaca99be365606ec720aae1ffeac82e6 100644 (file)
@@ -18,10 +18,11 @@ import sets, random
 
 class ColumnProperty(mapper.MapperProperty):
     """describes an object attribute that corresponds to a table column."""
-    def __init__(self, *columns):
+    def __init__(self, *columns, **kwargs):
         """the list of columns describes a single object property. if there
         are multiple tables joined together for the mapper, this list represents
         the equivalent column as it appears across each table."""
+        self.deepcheck = kwargs.get('deepcheck', False)
         self.columns = list(columns)
     def getattr(self, object):
         return getattr(object, self.key, None)
@@ -41,7 +42,7 @@ class ColumnProperty(mapper.MapperProperty):
         # establish a SmartProperty property manager on the object for this key
         if self.is_primary():
             #print "regiser col on class %s key %s" % (parent.class_.__name__, key)
-            sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist = False)
+            sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y))
     def execute(self, session, instance, row, identitykey, imap, isnew):
         if isnew:
             #print "POPULATING OBJ", instance.__class__.__name__, "COL", self.columns[0]._label, "WITH DATA", row[self.columns[0]], "ROW IS A", row.__class__.__name__, "COL ID", id(self.columns[0])
@@ -60,15 +61,15 @@ class DeferredColumnProperty(ColumnProperty):
     """describes an object attribute that corresponds to a table column, which also
     will "lazy load" its value from the table.  this is per-column lazy loading."""
     def __init__(self, *columns, **kwargs):
-        self.group = kwargs.get('group', None)
-        ColumnProperty.__init__(self, *columns)
+        self.group = kwargs.pop('group', None)
+        ColumnProperty.__init__(self, *columns, **kwargs)
     def copy(self):
         return DeferredColumnProperty(*self.columns)
     def do_init(self):
         # establish a SmartProperty property manager on the object for this key, 
         # containing a callable to load in the attribute
         if self.is_primary():
-            sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i))
+            sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=lambda i:self.setup_loader(i), copy_function=lambda x: self.columns[0].type.copy_value(x), compare_function=lambda x,y:self.columns[0].type.compare_values(x,y))
     def setup_loader(self, instance):
         if not self.localparent.is_assigned(instance):
             return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
@@ -789,6 +790,14 @@ class DeferredOption(GenericOption):
             prop = ColumnProperty(*oldprop.columns, **self.kwargs)
         mapper._compile_property(key, prop)
         
+class DeferGroupOption(mapper.MapperOption):
+    def __init__(self, group, defer=False, **kwargs):
+        self.group = group
+        self.defer = defer
+        self.kwargs = kwargs
+    def process(self, mapper):
+        self.process_by_key(mapper, self.key)
+    
 
 class BinaryVisitor(sql.ClauseVisitor):
     def __init__(self, func):
index e5c95cbe4beea53e98dd49d4137f25f0c1ab16c9..1858ae6645f4f3f1cf163c5d8172075054d3d5b3 100644 (file)
@@ -357,30 +357,19 @@ class Session(object):
             #    raise exceptions.InvalidRequestError("Instance '%s' is an orphan, and must be attached to a parent object to be saved" % (repr(object)))
             
             m._assign_entity_name(object)
-            self._register_new(object)
+            self._register_pending(object)
 
     def _update_impl(self, object, **kwargs):
         if self._is_attached(object) and object not in self.deleted:
             return
         if not hasattr(object, '_instance_key'):
             raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % repr(object))
-        if attribute_manager.is_modified(object):
-            self._register_dirty(object)
-        else:
-            self._register_clean(object)
+        self._register_persistent(object)
     
-    def _register_changed(self, obj):
-        if hasattr(obj, '_instance_key'):
-            self._register_dirty(obj)
-        else:
-            self._register_new(obj)
-    def _register_new(self, obj):
+    def _register_pending(self, obj):
         self._attach(obj)
         self.uow.register_new(obj)
-    def _register_dirty(self, obj):
-        self._attach(obj)
-        self.uow.register_dirty(obj)
-    def _register_clean(self, obj):
+    def _register_persistent(self, obj):
         self._attach(obj)
         self.uow.register_clean(obj)
     def _register_deleted(self, obj):
@@ -430,7 +419,7 @@ class Session(object):
     def has_key(self, key):
         return self.identity_map.has_key(key)
         
-    dirty = property(lambda s:s.uow.dirty, doc="a Set of all objects marked as 'dirty' within this Session")
+    dirty = property(lambda s:s.uow.locate_dirty(), doc="a Set of all objects marked as 'dirty' within this Session")
     deleted = property(lambda s:s.uow.deleted, doc="a Set of all objects marked as 'deleted' within this Session")
     new = property(lambda s:s.uow.new, doc="a Set of all objects marked as 'new' within this Session.")
     identity_map = property(lambda s:s.uow.identity_map, doc="a WeakValueDictionary consisting of all objects within this Session keyed to their _instance_key value.")
index 0a5669227180fae96e3ae8a2dbd8239d54f1acf2..113a60dda34487d36e76785f1d9e5ce8b285d18d 100644 (file)
@@ -23,11 +23,6 @@ import weakref
 import topological
 import sets
 
-# a global indicating if all flush() operations should have their plan
-# printed to standard output.  also can be affected by creating an engine
-# with the "echo_uow=True" keyword argument.
-LOG = False
-
 class UOWEventHandler(attributes.AttributeExtension):
     """an event handler added to all class attributes which handles session operations."""
     def __init__(self, key, class_, cascade=None):
@@ -35,9 +30,9 @@ class UOWEventHandler(attributes.AttributeExtension):
         self.class_ = class_
         self.cascade = cascade
     def append(self, event, obj, item):
+        # process "save_update" cascade rules for when an instance is appended to the list of another instance
         sess = object_session(obj)
         if sess is not None:
-            sess._register_changed(obj)
             if self.cascade is not None and self.cascade.save_update and item not in sess:
                 mapper = object_mapper(obj)
                 prop = mapper.props[self.key]
@@ -45,14 +40,14 @@ class UOWEventHandler(attributes.AttributeExtension):
                 sess.save_or_update(item, entity_name=ename)
 
     def delete(self, event, obj, item):
-        sess = object_session(obj)
-        if sess is not None:
-            sess._register_changed(obj)
+        # currently no cascade rules for removing an item from a list
+        # (i.e. it stays in the Session)
+        pass
 
     def set(self, event, obj, newvalue, oldvalue):
+        # process "save_update" cascade rules for when an instance is attached to another instance
         sess = object_session(obj)
         if sess is not None:
-            sess._register_changed(obj)
             if newvalue is not None and self.cascade is not None and self.cascade.save_update and newvalue not in sess:
                 mapper = object_mapper(obj)
                 prop = mapper.props[self.key]
@@ -75,7 +70,6 @@ class UOWAttributeManager(attributes.AttributeManager):
     def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
         return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs)
 
-
 class UnitOfWork(object):
     """main UOW object which stores lists of dirty/new/deleted objects.  provides top-level "flush" functionality as well as the transaction boundaries with the SQLEngine(s) involved in a write operation."""
     def __init__(self, identity_map=None):
@@ -85,8 +79,6 @@ class UnitOfWork(object):
             self.identity_map = weakref.WeakValueDictionary()
             
         self.new = util.Set() #OrderedSet()
-        self.dirty = util.Set()
-        
         self.deleted = util.Set()
 
     def _remove_deleted(self, obj):
@@ -96,10 +88,6 @@ class UnitOfWork(object):
             self.deleted.remove(obj)
         except KeyError:
             pass
-        try:
-            self.dirty.remove(obj)
-        except KeyError:
-            pass
         try:
             self.new.remove(obj)
         except KeyError:
@@ -110,12 +98,6 @@ class UnitOfWork(object):
             (not hasattr(obj, '_instance_key') and obj not in self.new):
             raise InvalidRequestError("Instance '%s' is not attached or pending within this session" % repr(obj))
 
-    def update(self, obj):
-        """called to add an object to this UnitOfWork as though it were loaded from the DB,
-        but is actually coming from somewhere else, like a web session or similar."""
-        self.identity_map[obj._instance_key] = obj
-        self.register_dirty(obj)
-        
     def register_attribute(self, class_, key, uselist, **kwargs):
         attribute_manager.register_attribute(class_, key, uselist, **kwargs)
 
@@ -123,10 +105,6 @@ class UnitOfWork(object):
         attribute_manager.set_callable(obj, key, func, uselist, **kwargs)
     
     def register_clean(self, obj):
-        try:
-            self.dirty.remove(obj)
-        except KeyError:
-            pass
         try:
             self.new.remove(obj)
         except KeyError:
@@ -147,44 +125,39 @@ class UnitOfWork(object):
         if obj not in self.new:
             self.new.add(obj)
             obj._sa_insert_order = len(self.new)
-        self.unregister_deleted(obj)
-        
-    def register_dirty(self, obj):
-        if obj not in self.dirty:
-            self._validate_obj(obj)
-            self.dirty.add(obj)
-        self.unregister_deleted(obj)
-        
-    def is_dirty(self, obj):
-        if obj not in self.dirty:
-            return False
-        else:
-            return True
         
     def register_deleted(self, obj):
         if obj not in self.deleted:
             self._validate_obj(obj)
             self.deleted.add(obj)  
-
-    def unregister_deleted(self, obj):
-        try:
-            self.deleted.remove(obj)
-        except KeyError:
-            pass
-            
+    
+    def locate_dirty(self):
+        return util.Set([x for x in self.identity_map.values() if x not in self.deleted and attribute_manager.is_modified(x)])
+        
     def flush(self, session, objects=None, echo=False):
+        # this context will track all the objects we want to save/update/delete,
+        # and organize a hierarchical dependency structure.  it also handles
+        # communication with the mappers and relationships to fire off SQL
+        # and synchronize attributes between related objects.
         flush_context = UOWTransaction(self, session)
 
+        # create the set of all objects we want to operate upon
         if objects is not None:
+            # specific list passed in
             objset = util.Set(objects)
         else:
-            objset = None
+            # or just everything
+            objset = util.Set(self.identity_map.values()).union(self.new)
+
+        # detect persistent objects that have changes
+        dirty = self.locate_dirty()
 
+        # store objects whose fate has been decided
         processed = util.Set()
-        for obj in [n for n in self.new] + [d for d in self.dirty]:
-            if objset is not None and not obj in objset:
-                continue
-            if obj in self.deleted or obj in processed:
+        
+        # put all saves/updates into the flush context.  detect orphans and throw them into deleted.
+        for obj in self.new.union(dirty).intersection(objset).difference(self.deleted):
+            if obj in processed:
                 continue
             if object_mapper(obj)._is_orphan(obj):
                 for c in [obj] + list(object_mapper(obj).cascade_iterator('delete', obj)):
@@ -195,7 +168,8 @@ class UnitOfWork(object):
             else:
                 flush_context.register_object(obj)
                 processed.add(obj)
-                
+        
+        # put all remaining deletes into the flush context.
         for obj in self.deleted:
             if (objset is not None and not obj in objset) or obj in processed:
                 continue
@@ -211,19 +185,7 @@ class UnitOfWork(object):
         trans.commit()
             
         flush_context.post_exec()
-        
 
-    def rollback_object(self, obj):
-        """'rolls back' the attributes that have been changed on an object instance."""
-        attribute_manager.rollback(obj)
-        try:
-            self.dirty.remove(obj)
-        except KeyError:
-            pass
-        try:
-            self.deleted.remove(obj)
-        except KeyError:
-            pass
             
 class UOWTransaction(object):
     """handles the details of organizing and executing transaction tasks 
@@ -374,7 +336,7 @@ class UOWTransaction(object):
         
         head = self._sort_dependencies()
         self.__modified = False
-        if LOG or echo:
+        if echo:
             if head is None:
                 print "Task dump: None"
             else:
@@ -383,7 +345,7 @@ class UOWTransaction(object):
             head.execute(self)
         #if self.__modified and head is not None:
         #    raise "Assertion failed ! new pre-execute dependency step should eliminate post-execute changes (except post_update stuff)."
-        if LOG or echo:
+        if echo:
             print "\nExecute complete\n"
             
     def post_exec(self):
index 86c94eea6c1c64660b5d6cdfd9efb88c7a6f2a15..5463c23968888936cea242711663b7d10eac86b3 100644 (file)
@@ -27,6 +27,12 @@ class AbstractType(object):
             return self._impl_dict
     impl_dict = property(_get_impl_dict)
 
+    def copy_value(self, value):
+        return value
+    def compare_values(self, x, y):
+        return x is y
+    def is_mutable(self):
+        return False
             
 class TypeEngine(AbstractType):
     def __init__(self, *args, **params):
@@ -85,7 +91,16 @@ class TypeDecorator(AbstractType):
         instance = self.__class__.__new__(self.__class__)
         instance.__dict__.update(self.__dict__)
         return instance
-        
+
+class MutableType(object):
+    """a mixin that marks a Type as holding a mutable object"""
+    def is_mutable(self):
+        return True
+    def copy_value(self, value):
+        raise NotImplementedError()
+    def compare_values(self, x, y):
+        return x == y
+    
 def to_instance(typeobj):
     if typeobj is None:
         return NULLTYPE
@@ -210,7 +225,7 @@ class Binary(TypeEngine):
     def adapt(self, impltype):
         return impltype(length=self.length)
 
-class PickleType(TypeDecorator):
+class PickleType(MutableType, TypeDecorator):
     impl = Binary
     def __init__(self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None):
        self.protocol = protocol
@@ -225,7 +240,11 @@ class PickleType(TypeDecorator):
       if value is None:
           return None
       return self.impl.convert_bind_param(self.pickler.dumps(value, self.protocol), dialect)
-
+    def copy_value(self, value):
+      return self.pickler.loads(self.pickler.dumps(value, self.protocol))
+    def compare_values(self, x, y):
+        return self.pickler.dumps(x, self.protocol) == self.pickler.dumps(y, self.protocol)
+        
 class Boolean(TypeEngine):
     pass
 
index d1c6978e127c15ea7bde086840cb34b17a187259..ca3937a98c0c6d9f636277d175c5fbf8835ae255 100644 (file)
@@ -294,6 +294,26 @@ class AttributesTest(PersistTest):
         b2.element = None
         assert not manager.get_history(b2, 'element').hasparent(f2)
 
+    def testaggressivediffs(self):
+        """test the 'double check for changes' behavior of check_modified"""
+        class Foo(object):pass
+        manager = attributes.AttributeManager()
+        manager.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x])
+        x = Foo()
+        x.element = ['one', 'two', 'three']    
+        manager.commit(x)
+        x.element[1] = 'five'
+        assert manager.is_modified(x)
+        
+        manager.reset_class_managed(Foo)
+        manager = attributes.AttributeManager()
+        manager.register_attribute(Foo, 'element', uselist=False)
+        x = Foo()
+        x.element = ['one', 'two', 'three']    
+        manager.commit(x)
+        x.element[1] = 'five'
+        assert not manager.is_modified(x)
+        
     def testdescriptorattributes(self):
         """changeset: 1633 broke ability to use ORM to map classes with unusual
         descriptor attributes (for example, classes that inherit from ones
index dd2836469707624903cf2aa0123c8a2607c4a42d..235edbdccac9853290799d71eca1c6a9eb1cf284 100644 (file)
@@ -592,7 +592,6 @@ class SelfReferentialPostUpdateTest(AssertMixin):
         remove_child(root, cats)
         # pre-trigger lazy loader on 'cats' to make the test easier
         cats.children
-
         self.assert_sql(db, lambda: session.flush(), [
             (
                 "UPDATE node SET prev_sibling_id=:prev_sibling_id WHERE node.id = :node_id",
index 35c5378fa526dff03458fa28065c6d9f37fa88d0..3642fef18099d21c6a029e0c764d845f6a574ec1 100644 (file)
@@ -1,8 +1,7 @@
 from testbase import PersistTest, AssertMixin
-import unittest, sys, os
 from sqlalchemy import *
-import StringIO
 import testbase
+import pickleable
 from sqlalchemy.orm.mapper import global_extensions
 from sqlalchemy.ext.sessioncontext import SessionContext
 import sqlalchemy.ext.assignmapper as assignmapper
@@ -34,36 +33,6 @@ class HistoryTest(UnitOfWorkTest):
         users.drop()
         UnitOfWorkTest.tearDownAll(self)
         
-    def testattr(self):
-        """tests the rolling back of scalar and list attributes.  this kind of thing
-        should be tested mostly in attributes.py which tests independently of the ORM 
-        objects, but I think here we are going for
-        the Mapper not interfering with it."""
-        m = mapper(User, users, properties = dict(addresses = relation(mapper(Address, addresses))))
-        u = User()
-        u.user_id = 7
-        u.user_name = 'afdas'
-        u.addresses.append(Address())
-        u.addresses[0].email_address = 'hi'
-        u.addresses.append(Address())
-        u.addresses[1].email_address = 'there'
-        data = [User,
-            {'user_name' : 'afdas',
-             'addresses' : (Address, [{'email_address':'hi'}, {'email_address':'there'}])
-            },
-        ]
-        self.assert_result([u], data[0], *data[1:])
-
-        self.echo(repr(u.addresses))
-        ctx.current.uow.rollback_object(u)
-        
-        # depending on the setting in the get() method of InstrumentedAttribute in attributes.py, 
-        # username is either None or is a non-present attribute.
-        assert u.user_name is None
-        #assert not hasattr(u, 'user_name')
-        
-        assert u.addresses == []
-
     def testbackref(self):
         s = create_session()
         class User(object):pass
@@ -261,6 +230,55 @@ class UnicodeTest(UnitOfWorkTest):
         ctx.current.clear()
         t1 = ctx.current.query(Test).get_by(id=t1.id)
         assert len(t1.t2s) == 2
+
+class MutableTypesTest(UnitOfWorkTest):
+    def setUpAll(self):
+        UnitOfWorkTest.setUpAll(self)
+        global metadata, table
+        metadata = BoundMetaData(testbase.db)
+        table = Table('mutabletest', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', PickleType, nullable=False))
+        table.create()
+    def tearDownAll(self):
+        table.drop()
+        UnitOfWorkTest.tearDownAll(self)
+
+    def testbasic(self):
+        """test that types marked as MutableType get changes detected on them"""
+        class Foo(object):pass
+        mapper(Foo, table)
+        f1 = Foo()
+        f1.data = pickleable.Bar(4,5)
+        ctx.current.flush()
+        ctx.current.clear()
+        f2 = ctx.current.query(Foo).get_by(id=f1.id)
+        assert f2.data == f1.data
+        f2.data.y = 19
+        ctx.current.flush()
+        ctx.current.clear()
+        f3 = ctx.current.query(Foo).get_by(id=f1.id)
+        print f2.data, f3.data
+        assert f3.data != f1.data
+        assert f3.data == pickleable.Bar(4, 19)
+
+    def testnocomparison(self):
+        """test that types marked as MutableType get changes detected on them when the type has no __eq__ method"""
+        class Foo(object):pass
+        mapper(Foo, table)
+        f1 = Foo()
+        f1.data = pickleable.BarWithoutCompare(4,5)
+        ctx.current.flush()
+        ctx.current.clear()
+        f2 = ctx.current.query(Foo).get_by(id=f1.id)
+        f2.data.y = 19
+        ctx.current.flush()
+        ctx.current.clear()
+        f3 = ctx.current.query(Foo).get_by(id=f1.id)
+        print f2.data, f3.data
+        assert (f3.data.x, f3.data.y) == (4,19)
+        
+        
         
 class PKTest(UnitOfWorkTest):
     @testbase.unsupported('mssql')
diff --git a/test/pickleable.py b/test/pickleable.py
new file mode 100644 (file)
index 0000000..fc4f691
--- /dev/null
@@ -0,0 +1,29 @@
+"""since the cPickle module as of py2.4 uses erroneous relative imports, define the various
+picklable classes here so we can test PickleType stuff without issue."""
+
+
+class Foo(object):
+    def __init__(self, moredata):
+        self.data = 'im data'
+        self.stuff = 'im stuff'
+        self.moredata = moredata
+    def __eq__(self, other):
+        return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
+
+
+class Bar(object):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    def __eq__(self, other):
+        return other.__class__ is self.__class__ and other.x==self.x and other.y==self.y
+    def __str__(self):
+        return "Bar(%d, %d)" % (self.x, self.y)
+
+class BarWithoutCompare(object):
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+    def __str__(self):
+        return "Bar(%d, %d)" % (self.x, self.y)
+    
\ No newline at end of file
index a3c1c0f1826fe7bdbced5d38aca9a793ab9e88ea..524834fec5bba91921007ba559bbed8a4fb0d3ff 100644 (file)
@@ -1,5 +1,6 @@
 from testbase import PersistTest, AssertMixin
 import testbase
+import pickleable
 from sqlalchemy import *
 import string,datetime, re, sys
 import sqlalchemy.engine.url as url
@@ -164,15 +165,6 @@ class UnicodeTest(AssertMixin):
         finally:
             db.engine.dialect.convert_unicode = prev_unicode
 
-class Foo(object):
-    def __init__(self, moredata):
-        self.data = 'im data'
-        self.stuff = 'im stuff'
-        self.moredata = moredata
-    def __eq__(self, other):
-        return other.data == self.data and other.stuff == self.stuff and other.moredata==self.moredata
-
-import pickle
 
 class BinaryTest(AssertMixin):
     def setUpAll(self):
@@ -185,14 +177,14 @@ class BinaryTest(AssertMixin):
         # construct PickleType with non-native pickle module, since cPickle uses relative module
         # loading and confuses this test's parent package 'sql' with the 'sqlalchemy.sql' package relative
        # to the 'types' module
-        Column('pickled', PickleType(pickler=pickle))
+        Column('pickled', PickleType)
         )
         binary_table.create()
     def tearDownAll(self):
         binary_table.drop()
     def testbinary(self):
-        testobj1 = Foo('im foo 1')
-        testobj2 = Foo('im foo 2')
+        testobj1 = pickleable.Foo('im foo 1')
+        testobj2 = pickleable.Foo('im foo 2')
 
         stream1 =self.get_module_stream('sqlalchemy.sql')
         stream2 =self.get_module_stream('sqlalchemy.schema')