]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- improved/fixed custom collection classes when giving it "set"/
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Mar 2007 19:59:39 +0000 (19:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Mar 2007 19:59:39 +0000 (19:59 +0000)
"sets.Set" classes or subclasses (was still looking for append()
methods on them during lazy loads)
- moved CustomCollectionsTest from unitofwork to relationships
- added more custom collections test to attributes module

CHANGES
lib/sqlalchemy/orm/attributes.py
test/orm/attributes.py
test/orm/relationships.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 71843f36000275b27c3e5b3c9f2936d79400cdc5..bc218cbcea2a262a5d60f4aee37da817ba493605 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,7 +6,11 @@
       on postgres.  Also, the true labelname is always attached as the
       accessor on the parent Selectable so theres no need to be aware
       of the genrerated label names [ticket:512].
-      
+- orm:
+    - improved/fixed custom collection classes when giving it "set"/
+      "sets.Set" classes or subclasses (was still looking for append()
+      methods on them during lazy loads)
+
 0.3.6
 - sql:
     - bindparam() names are now repeatable!  specify two
index af3487dfdb7b909f0d8f53fb343a912c1349d86f..b7d2050ffb6f68b2a959b65fbf6ee094a8258aba 100644 (file)
@@ -162,15 +162,6 @@ class InstrumentedAttribute(object):
         else:
             return []
 
-    def _adapt_list(self, data):
-        if self.typecallable is not None:
-            t = self.typecallable()
-            if data is not None:
-                [t.append(x) for x in data]
-            return t
-        else:
-            return data
-
     def initialize(self, obj):
         """Initialize this attribute on the given object instance.
 
@@ -215,7 +206,7 @@ class InstrumentedAttribute(object):
                         return InstrumentedAttribute.PASSIVE_NORESULT
                     self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key))
                     values = callable_()
-                    l = InstrumentedList(self, obj, self._adapt_list(values), init=False)
+                    l = InstrumentedList(self, obj, values, init=False)
 
                     # if a callable was executed, then its part of the "committed state"
                     # if any, so commit the newly loaded data
@@ -362,6 +353,7 @@ class InstrumentedAttribute(object):
 
 InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute)
 
+    
 class InstrumentedList(object):
     """Instrument a list-based attribute.
 
@@ -388,21 +380,42 @@ class InstrumentedList(object):
         # and the list attribute, which interferes with immediate garbage collection.
         self.__obj = weakref.ref(obj)
         self.key = attr.key
-        self.data = data or attr._blank_list()
 
         # adapt to lists or sets
         # TODO: make three subclasses of InstrumentedList that come off from a
         # metaclass, based on the type of data sent in
-        if hasattr(self.data, 'append'):
+        if attr.typecallable is not None:
+            self.data = attr.typecallable()
+        else:
+            self.data = data or attr._blank_list()
+        
+        if isinstance(self.data, list):
             self._data_appender = self.data.append
             self._clear_data = self._clear_list
-        elif hasattr(self.data, 'add'):
+        elif isinstance(self.data, util.Set):
             self._data_appender = self.data.add
             self._clear_data = self._clear_set
-        else:
-            raise exceptions.ArgumentError("Collection type " + repr(type(self.data)) + " has no append() or add() method")
-        if isinstance(self.data, dict):
+        elif isinstance(self.data, dict):
+            if not hasattr(self.data, 'append'):
+                raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__)
             self._clear_data = self._clear_dict
+        else:
+            if hasattr(self.data, 'append'):
+                self._data_appender = self.data.append
+            elif hasattr(self.data, 'add'):
+                self._data_appender = self.data.add
+            else:
+                raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__)
+
+            if hasattr(self.data, 'clear'):
+                self._clear_data = self._clear_set
+            else:
+                raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__)
+            
+        if data is not None and data is not self.data:
+            for elem in data:
+                self._data_appender(elem)
+                
 
         if init:
             for x in self.data:
@@ -475,7 +488,7 @@ class InstrumentedList(object):
         return repr(self.data)
 
     def __getattr__(self, attr):
-        """Proxie unknown methods and attributes to the underlying
+        """Proxy unknown methods and attributes to the underlying
         data array.  This allows custom list classes to be used.
         """
 
index 77486bb8efb4f83a3364d3d13fbb10dcb44ee1b2..7e0a22aff2b8b5025b90cdf73cb8349464db741c 100644 (file)
@@ -1,9 +1,10 @@
 from testbase import PersistTest
 import sqlalchemy.util as util
 import sqlalchemy.orm.attributes as attributes
+from sqlalchemy import exceptions
 import unittest, sys, os
 import pickle
-
+import testbase
 
 class MyTest(object):pass
 class MyTest2(object):pass
@@ -328,6 +329,49 @@ class AttributesTest(PersistTest):
 
         manager = attributes.AttributeManager()
         manager.reset_class_managed(Foo)
+    
+    def testcollectionclasses(self):
+        manager = attributes.AttributeManager()
+        class Foo(object):pass
+        manager.register_attribute(Foo, "collection", uselist=True, typecallable=set)
+        assert isinstance(Foo().collection.data, set)
         
+        manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict)
+        try:
+            Foo().collection
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "Dictionary collection class 'dict' must implement an append() method"
+
+        class MyDict(dict):
+            def append(self, item):
+                self[item.foo] = item
+        manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict)
+        assert isinstance(Foo().collection.data, MyDict)
+        
+        class MyColl(object):pass
+        manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl)
+        try:
+            Foo().collection
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no append() or add() method"
+        
+        class MyColl(object):
+            def __iter__(self):
+                return iter([])
+            def append(self, item):
+                pass
+        manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl)
+        try:
+            Foo().collection
+            assert False
+        except exceptions.ArgumentError, e:
+            assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no clear() method"
+
+        def foo(self):pass
+        MyColl.clear = foo
+        assert isinstance(Foo().collection.data, MyColl)
+            
 if __name__ == "__main__":
-    unittest.main()
+    testbase.main()
index c456120555dfba898d981fc750691c2d5e3398e2..7c2ec45728c1152f61476cad2fa926d9fe4ce7b9 100644 (file)
@@ -673,6 +673,50 @@ class TypeMatchTest(testbase.ORMTest):
         except exceptions.AssertionError, err:
             assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
 
+class CustomCollectionsTest(testbase.ORMTest):
+    def define_tables(self, metadata):
+        global sometable, someothertable
+        sometable = Table('sometable', metadata,
+            Column('col1',Integer, primary_key=True),
+            Column('data', String(30)))
+        someothertable = Table('someothertable', metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
+            Column('data', String(20))
+        )
+    def testbasic(self):
+        class MyList(list):
+            pass
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=MyList)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        assert isinstance(f.bars.data, MyList)
+    def testlazyload(self):
+        """test that a 'set' can be used as a collection and can lazyload."""
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=set)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.add(Bar())
+        f.bars.add(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+
 class ViewOnlyTest(testbase.ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,
     using overlapping PK column names (should not produce "conflicting column" error)"""
index 598fba78040be6f94ac76afa3b78931427ea6978..fd97c9c1a9d60e38651f9556b80a9917453956cb 100644 (file)
@@ -55,33 +55,6 @@ class HistoryTest(UnitOfWorkTest):
         u = s.query(m).select()[0]
         print u.addresses[0].user
 
-class CustomCollectionsTest(UnitOfWorkTest):
-    def setUpAll(self):
-        UnitOfWorkTest.setUpAll(self)
-        global sometable, metadata, someothertable
-        metadata = BoundMetaData(testbase.db)
-        sometable = Table('sometable', metadata,
-            Column('col1',Integer, primary_key=True))
-        someothertable = Table('someothertable', metadata, 
-            Column('col1', Integer, primary_key=True),
-            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
-            Column('data', String(20))
-        )
-    def testbasic(self):
-        class MyList(list):
-            pass
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=MyList)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        assert isinstance(f.bars.data, MyList)
-    def tearDownAll(self):
-        UnitOfWorkTest.tearDownAll(self)
             
 class VersioningTest(UnitOfWorkTest):
     def setUpAll(self):