From: Mike Bayer Date: Mon, 26 Mar 2007 19:59:39 +0000 (+0000) Subject: - improved/fixed custom collection classes when giving it "set"/ X-Git-Tag: rel_0_3_7~108 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7300b198b94a2ebc0d1ae939d8d632866df87340;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - improved/fixed custom collection classes when giving it "set"/ "sets.Set" classes or subclasses (was still looking for append() methods on them during lazy loads) - moved CustomCollectionsTest from unitofwork to relationships - added more custom collections test to attributes module --- diff --git a/CHANGES b/CHANGES index 71843f3600..bc218cbcea 100644 --- a/CHANGES +++ b/CHANGES @@ -6,7 +6,11 @@ on postgres. Also, the true labelname is always attached as the accessor on the parent Selectable so theres no need to be aware of the genrerated label names [ticket:512]. - +- orm: + - improved/fixed custom collection classes when giving it "set"/ + "sets.Set" classes or subclasses (was still looking for append() + methods on them during lazy loads) + 0.3.6 - sql: - bindparam() names are now repeatable! specify two diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index af3487dfdb..b7d2050ffb 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -162,15 +162,6 @@ class InstrumentedAttribute(object): else: return [] - def _adapt_list(self, data): - if self.typecallable is not None: - t = self.typecallable() - if data is not None: - [t.append(x) for x in data] - return t - else: - return data - def initialize(self, obj): """Initialize this attribute on the given object instance. @@ -215,7 +206,7 @@ class InstrumentedAttribute(object): return InstrumentedAttribute.PASSIVE_NORESULT self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) values = callable_() - l = InstrumentedList(self, obj, self._adapt_list(values), init=False) + l = InstrumentedList(self, obj, values, init=False) # if a callable was executed, then its part of the "committed state" # if any, so commit the newly loaded data @@ -362,6 +353,7 @@ class InstrumentedAttribute(object): InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) + class InstrumentedList(object): """Instrument a list-based attribute. @@ -388,21 +380,42 @@ class InstrumentedList(object): # and the list attribute, which interferes with immediate garbage collection. self.__obj = weakref.ref(obj) self.key = attr.key - self.data = data or attr._blank_list() # adapt to lists or sets # TODO: make three subclasses of InstrumentedList that come off from a # metaclass, based on the type of data sent in - if hasattr(self.data, 'append'): + if attr.typecallable is not None: + self.data = attr.typecallable() + else: + self.data = data or attr._blank_list() + + if isinstance(self.data, list): self._data_appender = self.data.append self._clear_data = self._clear_list - elif hasattr(self.data, 'add'): + elif isinstance(self.data, util.Set): self._data_appender = self.data.add self._clear_data = self._clear_set - else: - raise exceptions.ArgumentError("Collection type " + repr(type(self.data)) + " has no append() or add() method") - if isinstance(self.data, dict): + elif isinstance(self.data, dict): + if not hasattr(self.data, 'append'): + raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__) self._clear_data = self._clear_dict + else: + if hasattr(self.data, 'append'): + self._data_appender = self.data.append + elif hasattr(self.data, 'add'): + self._data_appender = self.data.add + else: + raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__) + + if hasattr(self.data, 'clear'): + self._clear_data = self._clear_set + else: + raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__) + + if data is not None and data is not self.data: + for elem in data: + self._data_appender(elem) + if init: for x in self.data: @@ -475,7 +488,7 @@ class InstrumentedList(object): return repr(self.data) def __getattr__(self, attr): - """Proxie unknown methods and attributes to the underlying + """Proxy unknown methods and attributes to the underlying data array. This allows custom list classes to be used. """ diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 77486bb8ef..7e0a22aff2 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -1,9 +1,10 @@ from testbase import PersistTest import sqlalchemy.util as util import sqlalchemy.orm.attributes as attributes +from sqlalchemy import exceptions import unittest, sys, os import pickle - +import testbase class MyTest(object):pass class MyTest2(object):pass @@ -328,6 +329,49 @@ class AttributesTest(PersistTest): manager = attributes.AttributeManager() manager.reset_class_managed(Foo) + + def testcollectionclasses(self): + manager = attributes.AttributeManager() + class Foo(object):pass + manager.register_attribute(Foo, "collection", uselist=True, typecallable=set) + assert isinstance(Foo().collection.data, set) + manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict) + try: + Foo().collection + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Dictionary collection class 'dict' must implement an append() method" + + class MyDict(dict): + def append(self, item): + self[item.foo] = item + manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict) + assert isinstance(Foo().collection.data, MyDict) + + class MyColl(object):pass + manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) + try: + Foo().collection + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no append() or add() method" + + class MyColl(object): + def __iter__(self): + return iter([]) + def append(self, item): + pass + manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) + try: + Foo().collection + assert False + except exceptions.ArgumentError, e: + assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no clear() method" + + def foo(self):pass + MyColl.clear = foo + assert isinstance(Foo().collection.data, MyColl) + if __name__ == "__main__": - unittest.main() + testbase.main() diff --git a/test/orm/relationships.py b/test/orm/relationships.py index c456120555..7c2ec45728 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -673,6 +673,50 @@ class TypeMatchTest(testbase.ORMTest): except exceptions.AssertionError, err: assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B) +class CustomCollectionsTest(testbase.ORMTest): + def define_tables(self, metadata): + global sometable, someothertable + sometable = Table('sometable', metadata, + Column('col1',Integer, primary_key=True), + Column('data', String(30))) + someothertable = Table('someothertable', metadata, + Column('col1', Integer, primary_key=True), + Column('scol1', Integer, ForeignKey(sometable.c.col1)), + Column('data', String(20)) + ) + def testbasic(self): + class MyList(list): + pass + class Foo(object): + pass + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=MyList) + }) + mapper(Bar, someothertable) + f = Foo() + assert isinstance(f.bars.data, MyList) + def testlazyload(self): + """test that a 'set' can be used as a collection and can lazyload.""" + class Foo(object): + pass + class Bar(object): + pass + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, collection_class=set) + }) + mapper(Bar, someothertable) + f = Foo() + f.bars.add(Bar()) + f.bars.add(Bar()) + sess = create_session() + sess.save(f) + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + class ViewOnlyTest(testbase.ORMTest): """test a view_only mapping where a third table is pulled into the primary join condition, using overlapping PK column names (should not produce "conflicting column" error)""" diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 598fba7804..fd97c9c1a9 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -55,33 +55,6 @@ class HistoryTest(UnitOfWorkTest): u = s.query(m).select()[0] print u.addresses[0].user -class CustomCollectionsTest(UnitOfWorkTest): - def setUpAll(self): - UnitOfWorkTest.setUpAll(self) - global sometable, metadata, someothertable - metadata = BoundMetaData(testbase.db) - sometable = Table('sometable', metadata, - Column('col1',Integer, primary_key=True)) - someothertable = Table('someothertable', metadata, - Column('col1', Integer, primary_key=True), - Column('scol1', Integer, ForeignKey(sometable.c.col1)), - Column('data', String(20)) - ) - def testbasic(self): - class MyList(list): - pass - class Foo(object): - pass - class Bar(object): - pass - mapper(Foo, sometable, properties={ - 'bars':relation(Bar, collection_class=MyList) - }) - mapper(Bar, someothertable) - f = Foo() - assert isinstance(f.bars.data, MyList) - def tearDownAll(self): - UnitOfWorkTest.tearDownAll(self) class VersioningTest(UnitOfWorkTest): def setUpAll(self):