From: Jason Kirtland Date: Fri, 14 Dec 2007 00:13:18 +0000 (+0000) Subject: - Raise an error when assigning a bogusly keyed dictionary to one of the builtin... X-Git-Tag: rel_0_4_2~54 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=273e48c9a95825541bd461a1d5402f2e65f95876;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Raise an error when assigning a bogusly keyed dictionary to one of the builtin dict-based collection types [ticket:886] - Collections gain a @converter framework for flexible validation and adaptation of bulk assignment - Bogus bulk assignments now raise TypeError instead of exceptions.ArgumentError --- diff --git a/CHANGES b/CHANGES index dcc6cbdb2f..e970b0eb05 100644 --- a/CHANGES +++ b/CHANGES @@ -138,7 +138,24 @@ CHANGES mapper.pks_by_table, mapper.cascade_callable(), MapperProperty.cascade_callable(), mapper.canload(), mapper._mapper_registry, attributes.AttributeManager - + + - Assigning an incompatible collection type to a relation attribute now + raises TypeError instead of sqlalchemy's ArgumentError. + + - Bulk assignment of a MappedCollection now raises an error if a key in the + incoming dictionary does not match the key that the collection's keyfunc + would use for that value. [ticket:886] + + - Custom collections can now specify a @converter method to translate + objects used in "bulk" assignment into a stream of values, as in:: + + obj.col = [newval1, newval2] + # or + obj.dictcol = {'foo': newval1, 'bar': newval2} + + The MappedCollection uses this hook to ensure that incoming key/value + pairs are sane from the collection's perspective. + - fixed endless loop issue when using lazy="dynamic" on both sides of a bi-directional relationship [ticket:872] @@ -171,7 +188,7 @@ CHANGES - fixed bug which could arise when using session.begin_nested() in conjunction with more than one level deep of enclosing session.begin() statements - + - dialects - MSSQL/PyODBC no longer has a global "set nocount on". 
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index a8699f832a..a26bc2b58d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -476,29 +476,22 @@ class CollectionAttributeImpl(AttributeImpl): if initiator is self: return - setting_type = util.duck_type_collection(value) - - if value is None or setting_type != self.collection_interface: - raise exceptions.ArgumentError( - "Incompatible collection type on assignment: %s is not %s-like" % - (type(value).__name__, self.collection_interface.__name__)) - - if hasattr(value, '_sa_adapter'): - value = list(getattr(value, '_sa_adapter')) - elif setting_type == dict: - value = value.values() + # we need a CollectionAdapter to adapt the incoming value to an + # assignable iterable. pulling a new collection first so that + # an adaptation exception does not trigger a lazy load of the + # old collection. + new_collection, user_data = self._build_collection(state) + new_values = list(new_collection.adapt_like_to_iterable(value)) old = self.get(state) old_collection = self.get_collection(state, old) - - new_collection, user_data = self._build_collection(state) idset = util.IdentitySet - constants = idset(old_collection or []).intersection(value or []) - additions = idset(value or []).difference(constants) + constants = idset(old_collection or []).intersection(new_values or []) + additions = idset(new_values or []).difference(constants) removals = idset(old_collection or []).difference(constants) - for member in value or []: + for member in new_values or (): if member in additions: new_collection.append_with_event(member) elif member in constants: diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index c2cd4cf09d..7334e46642 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -15,11 +15,11 @@ and return values to events:: from sqlalchemy.orm.collections import collection class 
MyClass(object): # ... - + @collection.adds(1) def store(self, item): self.data.append(item) - + @collection.removes_return() def pop(self): return self.data.pop() @@ -104,7 +104,7 @@ from sqlalchemy.util import attrgetter __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] - + def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -193,7 +193,7 @@ class collection(object): # Bundled as a class solely for ease of use: packaging, doc strings, # importability. - + def appender(cls, fn): """Tag the method as the collection appender. @@ -228,7 +228,7 @@ class collection(object): database contains rows that violate your collection semantics, you will need to get creative to fix the problem, as access via the collection will not work. - + If the appender method is internally instrumented, you must also receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. @@ -260,7 +260,7 @@ class collection(object): receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. """ - + setattr(fn, '_sa_instrument_role', 'remover') return fn remover = classmethod(remover) @@ -294,7 +294,7 @@ class collection(object): @collection.internally_instrumented def extend(self, items): ... """ - + setattr(fn, '_sa_instrumented', True) return fn internally_instrumented = classmethod(internally_instrumented) @@ -308,11 +308,44 @@ class collection(object): the instance. A single argument is passed: the collection adapter that has been linked, or None if unlinking. """ - + setattr(fn, '_sa_instrument_role', 'on_link') return fn on_link = classmethod(on_link) + def converter(cls, fn): + """Tag the method as the collection converter. 
+ + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + its sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be converted into an iterable + of dictionary values, and other types will simply be iterated. + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. + + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + """ + + setattr(fn, '_sa_instrument_role', 'converter') + return fn + converter = classmethod(converter) + def adds(cls, arg): """Mark the method as adding an entity to the collection. @@ -340,13 +373,13 @@ class collection(object): the method. The decorator argument indicates which method argument holds the SQLAlchemy-relevant value to be added, and return value, if any will be considered the value to remove. - + Arguments can be specified positionally (i.e. integer) or by name:: @collection.replaces(2) def __setitem__(self, index, item): ... """ - + def decorator(fn): setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) setattr(fn, '_sa_instrument_after', 'fire_remove_event') @@ -374,7 +407,7 @@ class collection(object): return fn return decorator removes = classmethod(removes) - + def removes_return(cls): """Mark the method as removing an entity in the collection. 
@@ -417,7 +450,7 @@ def collection_iter(collection): raise TypeError("'%s' object is not iterable" % type(collection).__name__) - + class CollectionAdapter(object): """Bridges between the ORM and arbitrary Python collections. @@ -452,6 +485,39 @@ class CollectionAdapter(object): if hasattr(data, '_sa_on_link'): getattr(data, '_sa_on_link')(None) + def adapt_like_to_iterable(self, obj): + """Converts collection-compatible objects to an iterable of values. + + Can be passed any type of object, and if the underlying collection + determines that it can be adapted into a stream of values it can + use, returns an iterable of values suitable for append()ing. + + This method may raise TypeError or any other suitable exception + if adaptation fails. + + If a converter implementation is not supplied on the collection, + a default duck-typing-based implementation is used. + """ + + converter = getattr(self._data(), '_sa_converter', None) + if converter is not None: + return converter(obj) + + setting_type = sautil.duck_type_collection(obj) + + if obj is None or setting_type != self.attr.collection_interface: + raise TypeError( + "Incompatible collection type: %s is not %s-like" % + (type(obj).__name__, self.attr.collection_interface.__name__)) + + # If the object is an adapted collection, return the (iterable) adapter. + if getattr(obj, '_sa_adapter', None) is not None: + return getattr(obj, '_sa_adapter') + elif setting_type == dict: + return getattr(obj, 'itervalues', getattr(obj, 'values'))() + else: + return iter(obj) + def append_with_event(self, item, initiator=None): """Add an entity to the collection, firing mutation events.""" @@ -504,7 +570,7 @@ class CollectionAdapter(object): mutation, and should be left as None unless you are passing along an initiator value from a chained operation. 
""" - + if initiator is not False and item is not None: self.attr.fire_append_event(self.owner_state, item, initiator) @@ -518,7 +584,7 @@ class CollectionAdapter(object): if initiator is not False and item is not None: self.attr.fire_remove_event(self.owner_state, item, initiator) - + def __getstate__(self): return { 'key': self.attr.key, 'owner_state': self.owner_state, @@ -598,7 +664,7 @@ def _instrument_class(cls): # FIXME: more formally document this as a decoratorless/Python 2.3 # option for specifying instrumentation. (likely doc'd here in code only, # not in online docs.) - # + # # __instrumentation__ = { # 'rolename': 'methodname', # ... # 'methods': { @@ -617,7 +683,7 @@ def _instrument_class(cls): raise exceptions.ArgumentError( "Can not instrument a built-in type. Use a " "subclass, even a trivial one.") - + collection_type = sautil.duck_type_collection(cls) if collection_type in __interfaces: roles = __interfaces[collection_type].copy() @@ -638,7 +704,8 @@ def _instrument_class(cls): # note role declarations if hasattr(method, '_sa_instrument_role'): role = method._sa_instrument_role - assert role in ('appender', 'remover', 'iterator', 'on_link') + assert role in ('appender', 'remover', 'iterator', + 'on_link', 'converter') roles[role] = name # transfer instrumentation requests from decorated function @@ -691,7 +758,7 @@ def _instrument_class(cls): for method, (before, argument, after) in methods.items(): setattr(cls, method, _instrument_membership_mutator(getattr(cls, method), - before, argument, after)) + before, argument, after)) # intern the role map for role, method in roles.items(): setattr(cls, '_sa_%s' % role, getattr(cls, method)) @@ -736,7 +803,7 @@ def _instrument_membership_mutator(method, before, argument, after): executor = None else: executor = getattr(args[0], '_sa_adapter', None) - + if before and executor: getattr(executor, before)(value, initiator) @@ -762,7 +829,7 @@ def __set(collection, item, _sa_initiator=None): executor = 
getattr(collection, '_sa_adapter', None) if executor: getattr(executor, 'fire_append_event')(item, _sa_initiator) - + def __del(collection, item, _sa_initiator=None): """Run del events, may eventually be inlined into decorators.""" @@ -770,11 +837,11 @@ def __del(collection, item, _sa_initiator=None): executor = getattr(collection, '_sa_adapter', None) if executor: getattr(executor, 'fire_remove_event')(item, _sa_initiator) - + def _list_decorators(): """Hand-turned instrumentation wrappers that can decorate any list-like class.""" - + def _tidy(fn): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') @@ -868,7 +935,7 @@ def _list_decorators(): fn(self, start, end, values) _tidy(__setslice__) return __setslice__ - + def __delslice__(fn): def __delslice__(self, start, end): for value in self[start:end]: @@ -883,7 +950,7 @@ def _list_decorators(): self.append(value) _tidy(extend) return extend - + def pop(fn): def pop(self, index=-1): item = fn(self, index) @@ -1101,7 +1168,7 @@ class InstrumentedList(list): 'remover': 'remove', 'iterator': '__iter__', } -class InstrumentedSet(sautil.Set): +class InstrumentedSet(sautil.Set): """An instrumented version of the built-in set (or Set).""" __instrumentation__ = { @@ -1109,7 +1176,7 @@ class InstrumentedSet(sautil.Set): 'remover': 'remove', 'iterator': '__iter__', } -class InstrumentedDict(dict): +class InstrumentedDict(dict): """An instrumented version of the built-in dict.""" __instrumentation__ = { @@ -1146,7 +1213,7 @@ class MappedCollection(dict): callable that takes an object and returns an object for use as a dictionary key. """ - + def __init__(self, keyfunc): """Create a new collection with keying provided by keyfunc. 
@@ -1169,10 +1236,10 @@ class MappedCollection(dict): self.__setitem__(key, value, _sa_initiator) set = collection.internally_instrumented(set) set = collection.appender(set) - + def remove(self, value, _sa_initiator=None): """Remove an item from the collection by value, consulting this instance's keyfunc for the key.""" - + key = self.keyfunc(value) # Let self[key] raise if key is not in this collection if self[key] != value: @@ -1185,3 +1252,26 @@ class MappedCollection(dict): self.__delitem__(key, _sa_initiator) remove = collection.internally_instrumented(remove) remove = collection.remover(remove) + + def _convert(self, dictlike): + """Validate and convert a dict-like object into values for set()ing. + + This is called behind the scenes when a MappedCollection is replaced + entirely by another collection, as in:: + + myobj.mappedcollection = {'a':obj1, 'b': obj2} # ... + + Raises a TypeError if the key in any (key, value) pair in the dictlike + object does not match the key that this collection's keyfunc would + have assigned for that value. + """ + + for incoming_key, value in sautil.dictlike_iteritems(dictlike): + new_key = self.keyfunc(value) + if incoming_key != new_key: + raise TypeError( + "Found incompatible key %r for value %r; this collection's " + "keying function requires a key of %r for this value." 
% ( + incoming_key, value, new_key)) + yield value + _convert = collection.converter(_convert) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 1fee4cef0a..705168d209 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -193,6 +193,30 @@ def duck_type_collection(specimen, default=None): else: return default +def dictlike_iteritems(dictlike): + """Return a (key, value) iterator for almost any dict-like object.""" + + if hasattr(dictlike, 'iteritems'): + return dictlike.iteritems() + elif hasattr(dictlike, 'items'): + return iter(dictlike.items()) + + getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None)) + if getter is None: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + + if hasattr(dictlike, 'iterkeys'): + def iterator(): + for key in dictlike.iterkeys(): + yield key, getter(key) + return iterator() + elif hasattr(dictlike, 'keys'): + return iter([(key, getter(key)) for key in dictlike.keys()]) + else: + raise TypeError( + "Object '%r' is not dict-like" % dictlike) + def assert_arg_type(arg, argtype, name): if isinstance(arg, argtype): return arg diff --git a/test/base/utils.py b/test/base/utils.py index 1cfcd8fb5a..932ad876a2 100644 --- a/test/base/utils.py +++ b/test/base/utils.py @@ -14,7 +14,7 @@ class OrderedDictTest(PersistTest): self.assert_(o.keys() == ['a', 'b', 'snack', 'c']) self.assert_(o.values() == [1, 2, 'attack', 3]) - + o.pop('snack') self.assert_(o.keys() == ['a', 'b', 'c']) @@ -49,7 +49,7 @@ class ColumnCollectionTest(PersistTest): assert False except exceptions.ArgumentError, e: assert str(e) == "__contains__ requires a string argument" - + def test_compare(self): cc1 = sql.ColumnCollection() cc2 = sql.ColumnCollection() @@ -66,24 +66,25 @@ class ColumnCollectionTest(PersistTest): class ArgSingletonTest(unittest.TestCase): def test_cleanout(self): util.ArgSingleton.instances.clear() - + class MyClass(object): __metaclass__ = util.ArgSingleton def __init__(self, x, y): self.x 
= x self.y = y - + m1 = MyClass(3, 4) m2 = MyClass(1, 5) m3 = MyClass(3, 4) assert m1 is m3 assert m2 is not m3 assert len(util.ArgSingleton.instances) == 2 - + m1 = m2 = m3 = None MyClass.dispose(MyClass) assert len(util.ArgSingleton.instances) == 0 + class ImmutableSubclass(str): pass @@ -222,5 +223,76 @@ class IdentitySetTest(unittest.TestCase): self.assertRaises(TypeError, hash, ids) +class DictlikeIteritemsTest(unittest.TestCase): + baseline = set([('a', 1), ('b', 2), ('c', 3)]) + + def _ok(self, instance): + iterator = util.dictlike_iteritems(instance) + self.assertEquals(set(iterator), self.baseline) + + def _notok(self, instance): + self.assertRaises(TypeError, + util.dictlike_iteritems, + instance) + + def test_dict(self): + d = dict(a=1,b=2,c=3) + self._ok(d) + + def test_subdict(self): + class subdict(dict): + pass + d = subdict(a=1,b=2,c=3) + self._ok(d) + + def test_UserDict(self): + import UserDict + d = UserDict.UserDict(a=1,b=2,c=3) + self._ok(d) + + def test_object(self): + self._notok(object()) + + def test_duck_1(self): + class duck1(object): + def iteritems(duck): + return iter(self.baseline) + self._ok(duck1()) + + def test_duck_2(self): + class duck2(object): + def items(duck): + return list(self.baseline) + self._ok(duck2()) + + def test_duck_3(self): + class duck3(object): + def iterkeys(duck): + return iter(['a', 'b', 'c']) + def __getitem__(duck, key): + return dict(a=1,b=2,c=3).get(key) + self._ok(duck3()) + + def test_duck_4(self): + class duck4(object): + def iterkeys(duck): + return iter(['a', 'b', 'c']) + self._notok(duck4()) + + def test_duck_5(self): + class duck5(object): + def keys(duck): + return ['a', 'b', 'c'] + def get(duck, key): + return dict(a=1,b=2,c=3).get(key) + self._ok(duck5()) + + def test_duck_6(self): + class duck6(object): + def keys(duck): + return ['a', 'b', 'c'] + self._notok(duck6()) + + if __name__ == "__main__": testbase.main() diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py index 
2accd2fc83..fe8b40255c 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/associationproxy.py @@ -288,13 +288,13 @@ class CustomDictTest(DictTest): try: p1._children = [] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) try: p1._children = None self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) self.assertRaises(TypeError, set, [p1.children]) @@ -404,13 +404,13 @@ class SetTest(_CollectionOperations): try: p1._children = [] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) try: p1._children = None self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(True) self.assertRaises(TypeError, set, [p1.children]) diff --git a/test/orm/collection.py b/test/orm/collection.py index 5d1753909a..43b2f41e25 100644 --- a/test/orm/collection.py +++ b/test/orm/collection.py @@ -269,7 +269,7 @@ class CollectionsTest(PersistTest): try: obj.attr = set([e4]) self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) self.assert_(e3 in canary.data) @@ -526,7 +526,7 @@ class CollectionsTest(PersistTest): try: obj.attr = [e4] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) self.assert_(e3 in canary.data) @@ -737,23 +737,42 @@ class CollectionsTest(PersistTest): self.assert_(e1 in canary.removed) self.assert_(e2 in canary.added) + + # key validity on bulk assignment is a basic feature of MappedCollection + # but is not present in basic, @converter-less dict collections. 
e3 = creator() - real_dict = dict(keyignored1=e3) - obj.attr = real_dict - self.assert_(obj.attr is not real_dict) - self.assert_('keyignored1' not in obj.attr) - self.assert_(set(collections.collection_adapter(obj.attr)) == set([e3])) - self.assert_(e2 in canary.removed) - self.assert_(e3 in canary.added) + if isinstance(obj.attr, collections.MappedCollection): + real_dict = dict(badkey=e3) + try: + obj.attr = real_dict + self.assert_(False) + except TypeError: + pass + self.assert_(obj.attr is not real_dict) + self.assert_('badkey' not in obj.attr) + self.assertEquals(set(collections.collection_adapter(obj.attr)), + set([e2])) + self.assert_(e3 not in canary.added) + else: + real_dict = dict(keyignored1=e3) + obj.attr = real_dict + self.assert_(obj.attr is not real_dict) + self.assert_('keyignored1' not in obj.attr) + self.assertEquals(set(collections.collection_adapter(obj.attr)), + set([e3])) + self.assert_(e2 in canary.removed) + self.assert_(e3 in canary.added) + + obj.attr = typecallable() + self.assertEquals(list(collections.collection_adapter(obj.attr)), []) e4 = creator() try: obj.attr = [e4] self.assert_(False) - except exceptions.ArgumentError: + except TypeError: self.assert_(e4 not in canary.data) - self.assert_(e3 in canary.data) - + def test_dict(self): try: self._test_adapter(dict, dictable_entity,