]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
[ticket:534] get dictionary append() method properly
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Apr 2007 21:15:11 +0000 (21:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Apr 2007 21:15:11 +0000 (21:15 +0000)
lib/sqlalchemy/orm/attributes.py
test/orm/relationships.py

index b7d2050ffb6f68b2a959b65fbf6ee094a8258aba..699a2f8875a9b72e22d0ffcb456cfdd490abddde 100644 (file)
@@ -396,7 +396,9 @@ class InstrumentedList(object):
             self._data_appender = self.data.add
             self._clear_data = self._clear_set
         elif isinstance(self.data, dict):
-            if not hasattr(self.data, 'append'):
+            if hasattr(self.data, 'append'):
+                self._data_appender = self.data.append
+            else:
                 raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__)
             self._clear_data = self._clear_dict
         else:
index e2ca39c5116cd86a5728fd7cfc2c57c199ec6b1a..fac484975d9f1a6a13bc562c8173717a7773e7a7 100644 (file)
@@ -716,6 +716,34 @@ class CustomCollectionsTest(testbase.ORMTest):
         sess.clear()
         f = sess.query(Foo).get(f.col1)
         assert len(list(f.bars)) == 2
+        f.bars.clear()
+        
+    def testdict(self):
+        """test that a 'dict' can be used as a collection and can lazyload."""
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        class AppenderDict(dict):
+            def append(self, item):
+                self[id(item)] = item
+            def __iter__(self):
+                return iter(self.values())
+                
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=AppenderDict)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.append(Bar())
+        f.bars.append(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+        f.bars.clear()
 
 class ViewOnlyTest(testbase.ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,