]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Removed equality, truth and hash() testing of mapped instances. Mapped
authorJason Kirtland <jek@discorporate.us>
Sat, 3 Nov 2007 20:23:26 +0000 (20:23 +0000)
committerJason Kirtland <jek@discorporate.us>
Sat, 3 Nov 2007 20:23:26 +0000 (20:23 +0000)
  classes can now implement arbitrary __eq__ and friends. [ticket:676]

13 files changed:
CHANGES
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/topological.py
lib/sqlalchemy/util.py
test/orm/mapper.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index c04dfe8ed6a7a731d461d0fe34316069562144b2..d58cce785728528b2879f3c80f084be500b655ce 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,6 +4,7 @@ CHANGES
 
 0.4.1
 -----
+
 - removed regular expression step from most statement compilations.
   also fixes [ticket:833]
 
@@ -37,6 +38,10 @@ CHANGES
   dialects using the old name.
 
 - orm:
+    - Mapped classes may now define __eq__, __hash__, and __nonzero__ methods
+      with arbitrary semantics.  The orm now handles all mapped instances on
+      an identity-only basis. (e.g. 'is' vs '==') [ticket:676]
+
     - deferred column attributes no longer trigger a load operation when the
       attribute is assigned to.  in those cases, the newly assigned
       value will be present in the flushes' UPDATE statement unconditionally.
@@ -68,7 +73,7 @@ CHANGES
       scoped sessions.
 
   - session API has been solidified:
-  
+
     - it's an error to session.save() an object which is already persistent
       [ticket:840]
 
@@ -76,7 +81,7 @@ CHANGES
 
     - session.update() and session.delete() raise an error when updating/deleting
       an instance that is already in the session with a different identity.
-      
+
     - session checks more carefully when determining "object X already in another session";
       e.g. if you pickle a series of objects and unpickle (i.e. as in a Pylons HTTP session
       or similar), they can go into a new session without any conflict
index 0ee59e3690f241180fc5e088721d52796bc96053..472bd1b2cc790fc6f878f2bab634eb1fede2587f 100644 (file)
@@ -326,6 +326,7 @@ class _AssociationList(object):
 
     def __contains__(self, value):
         for member in self.col:
+            # testlib.pragma exempt:__eq__
             if self._get(member) == value:
                 return True
         return False
@@ -473,6 +474,7 @@ class _AssociationDict(object):
         del self.col[key]
 
     def __contains__(self, key):
+        # testlib.pragma exempt:__hash__
         return key in self.col
     has_key = __contains__
 
@@ -609,6 +611,7 @@ class _AssociationSet(object):
 
     def __contains__(self, value):
         for member in self.col:
+            # testlib.pragma exempt:__eq__
             if self._get(member) == value:
                 return True
         return False
index 189cd52ee07cd356a942a4ed95e097258f6b8cd4..a340394b9bb93f1f75c0905e7addd26d37a85c12 100644 (file)
@@ -8,7 +8,7 @@ import weakref, threading
 import UserDict
 from sqlalchemy import util
 from sqlalchemy.orm import interfaces, collections
-from sqlalchemy.orm.mapper import class_mapper
+from sqlalchemy.orm.mapper import class_mapper, identity_equal
 from sqlalchemy import exceptions
 
 
@@ -369,6 +369,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         super(ScalarObjectAttributeImpl, self).__init__(class_, manager, key,
           callable_, trackparent=trackparent, extension=extension,
           compare_function=compare_function, mutable_scalars=mutable_scalars, **kwargs)
+        if compare_function is None:
+            self.is_equal = identity_equal
 
     def delete(self, state):
         old = self.get(state)
@@ -815,23 +817,26 @@ class AttributeHistory(object):
 
         if hasattr(attr, 'get_collection'):
             self._current = current
+
             if original is NO_VALUE:
-                s = util.Set([])
+                s = util.IdentitySet([])
             else:
-                s = util.Set(original)
-            self._added_items = []
-            self._unchanged_items = []
-            self._deleted_items = []
+                s = util.IdentitySet(original)
+
+            # FIXME: the tests have an assumption on the collection's ordering
+            self._added_items = util.OrderedIdentitySet()
+            self._unchanged_items = util.OrderedIdentitySet()
+            self._deleted_items = util.OrderedIdentitySet()
             if current:
                 collection = attr.get_collection(state, current)
                 for a in collection:
                     if a in s:
-                        self._unchanged_items.append(a)
+                        self._unchanged_items.add(a)
                     else:
-                        self._added_items.append(a)
+                        self._added_items.add(a)
             for a in s:
                 if a not in self._unchanged_items:
-                    self._deleted_items.append(a)
+                    self._deleted_items.add(a)
         else:
             self._current = [current]
             if attr.is_equal(current, original) is True:
@@ -853,13 +858,13 @@ class AttributeHistory(object):
         return len(self._deleted_items) > 0 or len(self._added_items) > 0
 
     def added_items(self):
-        return self._added_items
+        return list(self._added_items)
 
     def unchanged_items(self):
-        return self._unchanged_items
+        return list(self._unchanged_items)
 
     def deleted_items(self):
-        return self._deleted_items
+        return list(self._deleted_items)
 
 class AttributeManager(object):
     """Allow the instrumentation of object attributes."""
index bf365d267851a81216c7eb06d83e46cfb0c85200..9e6b0ce75670310c4af78db4ee3c2d9a99df7236 100644 (file)
@@ -793,6 +793,7 @@ def _list_decorators():
 
     def remove(fn):
         def remove(self, value, _sa_initiator=None):
+            # testlib.pragma exempt:__eq__
             fn(self, value)
             __del(self, value, _sa_initiator)
         _tidy(remove)
@@ -1002,22 +1003,27 @@ def _set_decorators():
     def add(fn):
         def add(self, value, _sa_initiator=None):
             __set(self, value, _sa_initiator)
+            # testlib.pragma exempt:__hash__
             fn(self, value)
         _tidy(add)
         return add
 
     def discard(fn):
         def discard(self, value, _sa_initiator=None):
+            # testlib.pragma exempt:__hash__
             if value in self:
                 __del(self, value, _sa_initiator)
+            # testlib.pragma exempt:__hash__
             fn(self, value)
         _tidy(discard)
         return discard
 
     def remove(fn):
         def remove(self, value, _sa_initiator=None):
+            # testlib.pragma exempt:__hash__
             if value in self:
                 __del(self, value, _sa_initiator)
+            # testlib.pragma exempt:__hash__
             fn(self, value)
         _tidy(remove)
         return remove
index a1669e32f53422165c22bd3ce7aebeb31c8871bc..f771dc5d729fa444244f049f16787059bc034717 100644 (file)
@@ -345,29 +345,29 @@ class ManyToManyDP(DependencyProcessor):
                 childlist = self.get_object_dependencies(obj, uowcommit, passive=self.passive_deletes)
                 if childlist is not None:
                     for child in childlist.deleted_items() + childlist.unchanged_items():
-                        if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+                        if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
                             continue
                         associationrow = {}
                         self._synchronize(obj, child, associationrow, False, uowcommit)
                         secondary_delete.append(associationrow)
-                        uowcommit.attributes[(self, "manytomany", obj, child)] = True
+                        uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
         else:
             for obj in deplist:
                 childlist = self.get_object_dependencies(obj, uowcommit)
                 if childlist is None: continue
                 for child in childlist.added_items():
-                    if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+                    if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
                         continue
                     associationrow = {}
                     self._synchronize(obj, child, associationrow, False, uowcommit)
-                    uowcommit.attributes[(self, "manytomany", obj, child)] = True
+                    uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
                     secondary_insert.append(associationrow)
                 for child in childlist.deleted_items():
-                    if child is None or (reverse_dep and (reverse_dep, "manytomany", child, obj) in uowcommit.attributes):
+                    if child is None or (reverse_dep and (reverse_dep, "manytomany", id(child), id(obj)) in uowcommit.attributes):
                         continue
                     associationrow = {}
                     self._synchronize(obj, child, associationrow, False, uowcommit)
-                    uowcommit.attributes[(self, "manytomany", obj, child)] = True
+                    uowcommit.attributes[(self, "manytomany", id(obj), id(child))] = True
                     secondary_delete.append(associationrow)
 
         if secondary_delete:
index 73c8321fc79665a4fa3569ec84c5747270e425fe..efc509725220f9b3ca96d07c7d410e117fc8ffa8 100644 (file)
@@ -1099,7 +1099,8 @@ class Mapper(object):
                     c = connection.execute(statement.values(value_params), params)
                     mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params)
 
-                    updated_objects.add((obj, connection))
+                    # testlib.pragma exempt:__hash__
+                    updated_objects.add((id(obj), obj, connection))
                     rows += c.rowcount
 
                 if c.supports_sane_rowcount() and rows != len(update):
@@ -1134,13 +1135,14 @@ class Mapper(object):
                             mapper._synchronizer.execute(obj, obj)
                     sync(mapper)
 
-                    inserted_objects.add((obj, connection))
+                    # testlib.pragma exempt:__hash__
+                    inserted_objects.add((id(obj), obj, connection))
         if not postupdate:
-            for obj, connection in inserted_objects:
+            for id_, obj, connection in inserted_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
                     if 'after_insert' in mapper.extension.methods:
                         mapper.extension.after_insert(mapper, connection, obj)
-            for obj, connection in updated_objects:
+            for id_, obj, connection in updated_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, obj)
@@ -1194,7 +1196,7 @@ class Mapper(object):
             for mapper in object_mapper(obj).iterate_to_root():
                 if 'before_delete' in mapper.extension.methods:
                     mapper.extension.before_delete(mapper, connection, obj)
-        
+
         deleted_objects = util.Set()
         table_to_mapper = {}
         for mapper in self.base_mapper.polymorphic_iterator():
@@ -1217,7 +1219,8 @@ class Mapper(object):
                     params[col.key] = mapper.get_attr_by_column(obj, col)
                 if mapper.version_id_col is not None:
                     params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
-                deleted_objects.add((obj, connection))
+                # testlib.pragma exempt:__hash__
+                deleted_objects.add((id(obj), obj, connection))
             for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
                 def comparator(a, b):
@@ -1237,7 +1240,7 @@ class Mapper(object):
                 if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
                     raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
 
-        for obj, connection in deleted_objects:
+        for id_, obj, connection in deleted_objects:
             for mapper in object_mapper(obj).iterate_to_root():
                 if 'after_delete' in mapper.extension.methods:
                     mapper.extension.after_delete(mapper, connection, obj)
@@ -1284,7 +1287,7 @@ class Mapper(object):
         """
 
         if recursive is None:
-            recursive=util.Set()
+            recursive=util.IdentitySet()
         for prop in self.__props.values():
             for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
                 yield c
@@ -1310,7 +1313,7 @@ class Mapper(object):
         """
 
         if recursive is None:
-            recursive=util.Set()
+            recursive=util.IdentitySet()
         for prop in self.__props.values():
             prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
 
@@ -1516,7 +1519,7 @@ class Mapper(object):
             selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
             
         if self.non_primary:
-            selectcontext.attributes[('populating_mapper', instance)] = self
+            selectcontext.attributes[('populating_mapper', id(instance))] = self
         
     def _post_instance(self, selectcontext, instance):
         post_processors = selectcontext.attributes[('post_processors', self, None)]
@@ -1577,6 +1580,15 @@ def has_mapper(object):
 
     return hasattr(object, '_entity_name')
 
+def identity_equal(a, b):
+    if a is b:
+        return True
+    id_a = getattr(a, '_instance_key', None)
+    id_b = getattr(b, '_instance_key', None)
+    if id_a is None or id_b is None:
+        return False
+    return id_a == id_b
+
 def object_mapper(object, entity_name=None, raiseerror=True):
     """Given an object, return the primary Mapper associated with the object instance.
     
index bec05a43fff10e5bf9e6fbdc26de5872336b605e..09a3a0f5b71dd9889b68a7160c52c000487a7e64 100644 (file)
@@ -691,7 +691,7 @@ class Query(object):
                 proc[0](context, row)
 
         for instance in context.identity_map.values():
-            context.attributes.get(('populating_mapper', instance), object_mapper(instance))._post_instance(context, instance)
+            context.attributes.get(('populating_mapper', id(instance)), object_mapper(instance))._post_instance(context, instance)
         
         # store new stuff in the identity map
         for instance in context.identity_map.values():
index f643302897ca9b7e94119a97c7787b302459b8c9..b699bfee5341d520da1229cff52360b32bc9c5cd 100644 (file)
@@ -611,8 +611,8 @@ class EagerLoader(AbstractRelationLoader):
                         appender = util.UniqueAppender(collection, 'append_without_event')
 
                         # store it in the "scratch" area, which is local to this load operation.
-                        selectcontext.attributes[(instance, self.key)] = appender
-                    result_list = selectcontext.attributes[(instance, self.key)]
+                        selectcontext.attributes[('appender', id(instance), self.key)] = appender
+                    result_list = selectcontext.attributes[('appender', id(instance), self.key)]
                     if self._should_log_debug:
                         self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
                         
index 0ce354d6f3b673a608f0b9ea849efcbba8f573f6..7f9a4d7d06e5f0a1f2454f26929216a7e4726d01 100644 (file)
@@ -93,8 +93,8 @@ class UnitOfWork(object):
         else:
             self.identity_map = {}
 
-        self.new = util.Set() #OrderedSet()
-        self.deleted = util.Set()
+        self.new = util.IdentitySet() #OrderedSet()
+        self.deleted = util.IdentitySet()
         self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
 
     def _remove_deleted(self, obj):
@@ -150,7 +150,7 @@ class UnitOfWork(object):
         """
         
         # a little bit of inlining for speed
-        return util.Set([x for x in self.identity_map.values() 
+        return util.IdentitySet([x for x in self.identity_map.values() 
             if x not in self.deleted 
             and (
                 x._state.modified
@@ -180,13 +180,13 @@ class UnitOfWork(object):
         # create the set of all objects we want to operate upon
         if objects is not None:
             # specific list passed in
-            objset = util.Set(objects)
+            objset = util.IdentitySet(objects)
         else:
             # or just everything
-            objset = util.Set(self.identity_map.values()).union(self.new)
+            objset = util.IdentitySet(self.identity_map.values()).union(self.new)
             
         # store objects whose fate has been decided
-        processed = util.Set()
+        processed = util.IdentitySet()
 
         # put all saves/updates into the flush context.  detect top-level orphans and throw them into deleted.
         for obj in self.new.union(dirty).intersection(objset).difference(self.deleted):
@@ -305,7 +305,7 @@ class UOWTransaction(object):
         """
         mapper = object_mapper(obj)
         task = self.get_task_by_mapper(mapper)
-        taskelement = task._objects[obj]
+        taskelement = task._objects[id(obj)]
         taskelement.isdelete = "rowswitch"
         
     def unregister_object(self, obj):
@@ -315,7 +315,7 @@ class UOWTransaction(object):
         no further operations occur upon the instance."""
         mapper = object_mapper(obj)
         task = self.get_task_by_mapper(mapper)
-        if obj in task._objects:
+        if id(obj) in task._objects:
             task.delete(obj)
 
     def is_deleted(self, obj):
@@ -615,11 +615,11 @@ class UOWTask(object):
         """
 
         try:
-            rec = self._objects[obj]
+            rec = self._objects[id(obj)]
             retval = False
         except KeyError:
             rec = UOWTaskElement(obj)
-            self._objects[obj] = rec
+            self._objects[id(obj)] = rec
             retval = True
         if not listonly:
             rec.listonly = False
@@ -646,7 +646,7 @@ class UOWTask(object):
         """remove the given object from this UOWTask, if present."""
         
         try:
-            del self._objects[obj]
+            del self._objects[id(obj)]
         except KeyError:
             pass
 
@@ -654,7 +654,7 @@ class UOWTask(object):
         """return True if the given object is contained within this UOWTask or inheriting tasks."""
         
         for task in self.polymorphic_tasks():
-            if obj in task._objects:
+            if id(obj) in task._objects:
                 return True
         else:
             return False
@@ -663,7 +663,7 @@ class UOWTask(object):
         """return True if the given object is marked as to be deleted within this UOWTask."""
         
         try:
-            return self._objects[obj].isdelete
+            return self._objects[id(obj)].isdelete
         except KeyError:
             return False
 
@@ -735,9 +735,9 @@ class UOWTask(object):
 
         def get_dependency_task(obj, depprocessor):
             try:
-                dp = dependencies[obj]
+                dp = dependencies[id(obj)]
             except KeyError:
-                dp = dependencies.setdefault(obj, {})
+                dp = dependencies.setdefault(id(obj), {})
             try:
                 l = dp[depprocessor]
             except KeyError:
@@ -766,7 +766,7 @@ class UOWTask(object):
             for subtask in task.polymorphic_tasks():
                 for taskelement in subtask.elements:
                     obj = taskelement.obj
-                    object_to_original_task[obj] = subtask
+                    object_to_original_task[id(obj)] = subtask
                     for dep in deps_by_targettask.get(subtask, []):
                         # is this dependency involved in one of the cycles ?
                         if not dependency_in_cycles(dep):
@@ -795,7 +795,7 @@ class UOWTask(object):
                             # task
                             if o not in childtask:
                                 childtask.append(o, listonly=True)
-                                object_to_original_task[o] = childtask
+                                object_to_original_task[id(o)] = childtask
 
                             # create a tuple representing the "parent/child"
                             whosdep = dep.whose_dependent_on_who(obj, o)
@@ -821,17 +821,17 @@ class UOWTask(object):
         
         used_tasks = util.Set()
         def make_task_tree(node, parenttask, nexttasks):
-            originating_task = object_to_original_task[node.item]
+            originating_task = object_to_original_task[id(node.item)]
             used_tasks.add(originating_task)
             t = nexttasks.get(originating_task, None)
             if t is None:
                 t = UOWTask(self.uowtransaction, originating_task.mapper)
                 nexttasks[originating_task] = t
-                parenttask.append(None, listonly=False, isdelete=originating_task._objects[node.item].isdelete, childtask=t)
-            t.append(node.item, originating_task._objects[node.item].listonly, isdelete=originating_task._objects[node.item].isdelete)
+                parenttask.append(None, listonly=False, isdelete=originating_task._objects[id(node.item)].isdelete, childtask=t)
+            t.append(node.item, originating_task._objects[id(node.item)].listonly, isdelete=originating_task._objects[id(node.item)].isdelete)
 
-            if node.item in dependencies:
-                for depprocessor, deptask in dependencies[node.item].iteritems():
+            if id(node.item) in dependencies:
+                for depprocessor, deptask in dependencies[id(node.item)].iteritems():
                     t.cyclical_dependencies.add(depprocessor.branch(deptask))
             nd = {}
             for n in node.children:
@@ -861,7 +861,7 @@ class UOWTask(object):
                 # or "delete" members due to inheriting mappers which contain tasks
                 localtask = UOWTask(self.uowtransaction, t2.mapper)
                 for obj in t2.elements:
-                    localtask.append(obj, t2.listonly, isdelete=t2._objects[obj].isdelete)
+                    localtask.append(obj, t2.listonly, isdelete=t2._objects[id(obj)].isdelete)
                 for dep in t2.dependencies:
                     localtask._dependencies.add(dep)
                 t.childtasks.insert(0, localtask)
index d38c5cf4a6056d2bce80a53ec81845696b9d1ba8..a47968519df7ff33ab56cad1b76805eedf1218ae 100644 (file)
@@ -153,20 +153,20 @@ class QueueDependencySorter(object):
         nodes = {}
         edges = _EdgeCollection()
         for item in allitems + [t[0] for t in tuples] + [t[1] for t in tuples]:
-            if item not in nodes:
+            if id(item) not in nodes:
                 node = _Node(item)
-                nodes[item] = node
+                nodes[id(item)] = node
 
         for t in tuples:
             if t[0] is t[1]:
                 if allow_self_cycles:
-                    n = nodes[t[0]]
+                    n = nodes[id(t[0])]
                     n.cycles = util.Set([n])
                     continue
                 else:
                     raise CircularDependencyError("Self-referential dependency detected " + repr(t))
-            childnode = nodes[t[1]]
-            parentnode = nodes[t[0]]
+            childnode = nodes[id(t[1])]
+            parentnode = nodes[id(t[0])]
             edges.add((parentnode, childnode))
 
         queue = []
@@ -202,7 +202,7 @@ class QueueDependencySorter(object):
             node = queue.pop()
             if not hasattr(node, '_cyclical'):
                 output.append(node)
-            del nodes[node.item]
+            del nodes[id(node.item)]
             for childnode in edges.pop_node(node):
                 queue.append(childnode)
         return self._create_batched_tree(output)
index a4ccaac6ab88705cdde65baf16feb09f25bdc55e..9ad7e113c41699e72158db13130bf9116758d891 100644 (file)
@@ -620,6 +620,7 @@ class IdentitySet(object):
 
     def union(self, iterable):
         result = type(self)()
+        # testlib.pragma exempt:__hash__
         result._members.update(
             Set(self._members.iteritems()).union(_iter_id(iterable)))
         return result
@@ -641,6 +642,7 @@ class IdentitySet(object):
 
     def difference(self, iterable):
         result = type(self)()
+        # testlib.pragma exempt:__hash__
         result._members.update(
             Set(self._members.iteritems()).difference(_iter_id(iterable)))
         return result
@@ -662,6 +664,7 @@ class IdentitySet(object):
 
     def intersection(self, iterable):
         result = type(self)()
+        # testlib.pragma exempt:__hash__
         result._members.update(
             Set(self._members.iteritems()).intersection(_iter_id(iterable)))
         return result
@@ -683,6 +686,7 @@ class IdentitySet(object):
 
     def symmetric_difference(self, iterable):
         result = type(self)()
+        # testlib.pragma exempt:__hash__
         result._members.update(
             Set(self._members.iteritems()).symmetric_difference(_iter_id(iterable)))
         return result
@@ -725,13 +729,25 @@ def _iter_id(iterable):
         yield id(item), item
 
 
+class OrderedIdentitySet(IdentitySet):
+    def __init__(self, iterable=None):
+        IdentitySet.__init__(self)
+        self._members = OrderedDict()
+        if iterable:
+            for o in iterable:
+                self.add(o)
+
+
 class UniqueAppender(object):
-    """appends items to a collection such that only unique items
-    are added."""
+    """Only adds items to a collection once.
+
+    Additional appends() of the same object are ignored.  Membership is
+    determined by identity (``is``) not equality (``==``).
+    """
 
     def __init__(self, data, via=None):
         self.data = data
-        self._unique = Set()
+        self._unique = IdentitySet()
         if via:
             self._data_appender = getattr(data, via)
         elif hasattr(data, 'append'):
index c9729944af65d6a09d549922af227c34e35df5a2..b37c985a60a38303f0528f84af1fe8b7328ed7e3 100644 (file)
@@ -1176,5 +1176,175 @@ class MapperExtensionTest(MapperSuperTest):
                 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance'])
         
 
-if __name__ == "__main__":    
+class RequirementsTest(AssertMixin):
+    """Tests the contract for user classes."""
+
+    def setUpAll(self):
+        global metadata, t1, t2, t3, t4, t5, t6
+
+        metadata = MetaData(testbase.db)
+        t1 = Table('ht1', metadata,
+                   Column('id', Integer, primary_key=True),
+                   Column('value', String(10)))
+        t2 = Table('ht2', metadata,
+                   Column('id', Integer, primary_key=True),
+                   Column('ht1_id', Integer, ForeignKey('ht1.id')),
+                   Column('value', String(10)))
+        t3 = Table('ht3', metadata,
+                   Column('id', Integer, primary_key=True),
+                   Column('value', String(10)))
+        t4 = Table('ht4', metadata,
+                   Column('ht1_id', Integer, ForeignKey('ht1.id'),
+                          primary_key=True),
+                   Column('ht3_id', Integer, ForeignKey('ht3.id'),
+                          primary_key=True))
+        t5 = Table('ht5', metadata,
+                   Column('ht1_id', Integer, ForeignKey('ht1.id'),
+                          primary_key=True),
+                   Column('ht1_id', Integer, ForeignKey('ht1.id'),
+                          primary_key=True))
+        t6 = Table('ht6', metadata,
+                   Column('ht1a_id', Integer, ForeignKey('ht1.id'),
+                          primary_key=True),
+                   Column('ht1b_id', Integer, ForeignKey('ht1.id'),
+                          primary_key=True),
+                   Column('value', String(10)))
+        metadata.create_all()
+
+    def setUp(self):
+        clear_mappers()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    def test_baseclass(self):
+        class OldStyle:
+            pass
+
+        self.assertRaises(exceptions.ArgumentError, mapper, OldStyle, t1)
+
+        class NoWeakrefSupport(str):
+            pass
+
+        # TODO: is weakref support detectable without an instance?
+        #self.assertRaises(exceptions.ArgumentError, mapper, NoWeakrefSupport, t2)
+
+    def test_comparison_overrides(self):
+        """Simple tests to ensure users can supply comparison __methods__.
+
+        The suite-level test --options are better suited to detect
+        problems- they add selected __methods__ across the board on all
+        ORM tests.  This test simply shoves a variety of operations
+        through the ORM to catch basic regressions early in a standard
+        test run.
+        """
+
+        # adding these methods directly to each class to avoid decoration
+        # by the testlib decorators.
+        class H1(object):
+            def __init__(self, value='abc'):
+                self.value = value
+            def __nonzero__(self):
+                return False
+            def __hash__(self):
+                return hash(self.value)
+            def __eq__(self, other):
+                if isinstance(other, type(self)):
+                    return self.value == other.value
+                return False
+        class H2(object):
+            def __init__(self, value='abc'):
+                self.value = value
+            def __nonzero__(self):
+                return False
+            def __hash__(self):
+                return hash(self.value)
+            def __eq__(self, other):
+                if isinstance(other, type(self)):
+                    return self.value == other.value
+                return False
+        class H3(object):
+            def __init__(self, value='abc'):
+                self.value = value
+            def __nonzero__(self):
+                return False
+            def __hash__(self):
+                return hash(self.value)
+            def __eq__(self, other):
+                if isinstance(other, type(self)):
+                    return self.value == other.value
+                return False
+        class H6(object):
+            def __init__(self, value='abc'):
+                self.value = value
+            def __nonzero__(self):
+                return False
+            def __hash__(self):
+                return hash(self.value)
+            def __eq__(self, other):
+                if isinstance(other, type(self)):
+                    return self.value == other.value
+                return False
+
+        mapper(H1, t1, properties={
+            'h2s': relation(H2, backref='h1'),
+            'h3s': relation(H3, secondary=t4, backref='h1s'),
+            'h1s': relation(H1, secondary=t5, backref='parent_h1'),
+            't6a': relation(H6, backref='h1a',
+                            primaryjoin=t1.c.id==t6.c.ht1a_id),
+            't6b': relation(H6, backref='h1b',
+                            primaryjoin=t1.c.id==t6.c.ht1b_id),
+            })
+        mapper(H2, t2)
+        mapper(H3, t3)
+        mapper(H6, t6)
+
+        s = create_session()
+        for i in range(3):
+            h1 = H1()
+            s.save(h1)
+
+        h1.h2s.append(H2())
+        h1.h3s.extend([H3(), H3()])
+        h1.h1s.append(H1())
+
+        s.flush()
+
+        h6 = H6()
+        h6.h1a = h1
+        h6.h1b = h1
+
+        h6 = H6()
+        h6.h1a = h1
+        h6.h1b = H1()
+
+        h6.h1b.h2s.append(H2())
+
+        s.flush()
+
+        h1.h2s.extend([H2(), H2()])
+        s.flush()
+
+        h1s = s.query(H1).options(eagerload('h2s')).all()
+        self.assertEqual(len(h1s), 5)
+
+        self.assert_unordered_result(h1s, H1,
+                                     {'h2s': []},
+                                     {'h2s': []},
+                                     {'h2s': (H2, [{'value': 'abc'},
+                                                   {'value': 'abc'},
+                                                   {'value': 'abc'}])},
+                                     {'h2s': []},
+                                     {'h2s': (H2, [{'value': 'abc'}])})
+
+        h1s = s.query(H1).options(eagerload('h3s')).all()
+
+        self.assertEqual(len(h1s), 5)
+        h1s = s.query(H1).options(eagerload_all('t6a.h1b'),
+                                  eagerload('h2s'),
+                                  eagerload_all('h3s.h1s')).all()
+        self.assertEqual(len(h1s), 5)
+
+
+if __name__ == "__main__":
     testbase.main()
index 3336a0783aa28bdd121cca047bdf25179a7b9085..b985cc8a50f8b1ed19b4383edf219cae005285f6 100644 (file)
@@ -1820,7 +1820,7 @@ class RowSwitchTest(ORMTest):
         sess.flush()
 
         assert list(sess.execute(t1.select(), mapper=T1)) == [(1, 'some t1')]
-        assert list(sess.execute(t1t3.select(), mapper=T1)) == [(1,1), (1, 2)]
+        assert rowset(sess.execute(t1t3.select(), mapper=T1)) == set([(1,1), (1, 2)])
         assert list(sess.execute(t3.select(), mapper=T1)) == [(1, 'some t3'), (2, 'some other t3')]
 
         o2 = T1(data='some other t1', id=1, t3s=[