]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Raise an error when assigning a bogusly keyed dictionary to one of the builtin...
authorJason Kirtland <jek@discorporate.us>
Fri, 14 Dec 2007 00:13:18 +0000 (00:13 +0000)
committerJason Kirtland <jek@discorporate.us>
Fri, 14 Dec 2007 00:13:18 +0000 (00:13 +0000)
- Collections gain a @converter framework for flexible validation and adaptation of bulk assignment
- Bogus bulk assignments now raise TypeError instead of exceptions.ArgumentError

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/util.py
test/base/utils.py
test/ext/associationproxy.py
test/orm/collection.py

diff --git a/CHANGES b/CHANGES
index dcc6cbdb2fcd213b5f9392c6e3fc9bba5399301f..e970b0eb0570b0c919a72dc24e7ac91dd6549b2f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -138,7 +138,24 @@ CHANGES
      mapper.pks_by_table, mapper.cascade_callable(), 
      MapperProperty.cascade_callable(), mapper.canload(),
      mapper._mapper_registry, attributes.AttributeManager
-     
+
+   - Assigning an incompatible collection type to a relation attribute now
+     raises TypeError instead of sqlalchemy's ArgumentError.
+
+   - Bulk assignment of a MappedCollection now raises an error if a key in the
+     incoming dictionary does not match the key that the collection's keyfunc
+     would use for that value. [ticket:886]
+
+   - Custom collections can now specify a @converter method to translate
+     objects used in "bulk" assignment into a stream of values, as in::
+
+        obj.col = [newval1, newval2]
+        # or
+        obj.dictcol = {'foo': newval1, 'bar': newval2}
+
+     The MappedCollection uses this hook to ensure that incoming key/value
+     pairs are sane from the collection's perspective.
+
    - fixed endless loop issue when using lazy="dynamic" on both 
      sides of a bi-directional relationship [ticket:872]
 
@@ -171,7 +188,7 @@ CHANGES
 
    - fixed bug which could arise when using session.begin_nested() in conjunction
      with more than one level deep of enclosing session.begin() statements
-     
+
 - dialects
 
    - MSSQL/PyODBC no longer has a global "set nocount on".
index a8699f832a4bebf64280159cdb808138318cc7fd..a26bc2b58de776d2b8c7476251b39007212f65a4 100644 (file)
@@ -476,29 +476,22 @@ class CollectionAttributeImpl(AttributeImpl):
         if initiator is self:
             return
 
-        setting_type = util.duck_type_collection(value)
-
-        if value is None or setting_type != self.collection_interface:
-            raise exceptions.ArgumentError(
-                "Incompatible collection type on assignment: %s is not %s-like" %
-                (type(value).__name__, self.collection_interface.__name__))
-
-        if hasattr(value, '_sa_adapter'):
-            value = list(getattr(value, '_sa_adapter'))
-        elif setting_type == dict:
-            value = value.values()
+        # we need a CollectionAdapter to adapt the incoming value to an
+        # assignable iterable.  pulling a new collection first so that
+        # an adaptation exception does not trigger a lazy load of the
+        # old collection.
+        new_collection, user_data = self._build_collection(state)
+        new_values = list(new_collection.adapt_like_to_iterable(value))
 
         old = self.get(state)
         old_collection = self.get_collection(state, old)
-        
-        new_collection, user_data = self._build_collection(state)
 
         idset = util.IdentitySet
-        constants = idset(old_collection or []).intersection(value or [])
-        additions = idset(value or []).difference(constants)
+        constants = idset(old_collection or []).intersection(new_values or [])
+        additions = idset(new_values or []).difference(constants)
         removals  = idset(old_collection or []).difference(constants)
 
-        for member in value or []:
+        for member in new_values or ():
             if member in additions:
                 new_collection.append_with_event(member)
             elif member in constants:
index c2cd4cf09dfe1ca7024492fd46a708d297cdd931..7334e466421b6fdecd8c797008193a81a0b54f51 100644 (file)
@@ -15,11 +15,11 @@ and return values to events::
   from sqlalchemy.orm.collections import collection
   class MyClass(object):
       # ...
-      
+
       @collection.adds(1)
       def store(self, item):
           self.data.append(item)
-      
+
       @collection.removes_return()
       def pop(self):
           return self.data.pop()
@@ -104,7 +104,7 @@ from sqlalchemy.util import attrgetter
 __all__ = ['collection', 'collection_adapter',
            'mapped_collection', 'column_mapped_collection',
            'attribute_mapped_collection']
-           
+
 def column_mapped_collection(mapping_spec):
     """A dictionary-based collection type with column-based keying.
 
@@ -193,7 +193,7 @@ class collection(object):
 
     # Bundled as a class solely for ease of use: packaging, doc strings,
     # importability.
-    
+
     def appender(cls, fn):
         """Tag the method as the collection appender.
 
@@ -228,7 +228,7 @@ class collection(object):
         database contains rows that violate your collection semantics, you
         will need to get creative to fix the problem, as access via the
         collection will not work.
-     
+
         If the appender method is internally instrumented, you must also
         receive the keyword argument '_sa_initiator' and ensure its
         promulgation to collection events.
@@ -260,7 +260,7 @@ class collection(object):
         receive the keyword argument '_sa_initiator' and ensure its
         promulgation to collection events.
         """
-        
+
         setattr(fn, '_sa_instrument_role', 'remover')
         return fn
     remover = classmethod(remover)
@@ -294,7 +294,7 @@ class collection(object):
             @collection.internally_instrumented
             def extend(self, items): ...
         """
-        
+
         setattr(fn, '_sa_instrumented', True)
         return fn
     internally_instrumented = classmethod(internally_instrumented)
@@ -308,11 +308,44 @@ class collection(object):
         the instance.  A single argument is passed: the collection adapter
         that has been linked, or None if unlinking.
         """
-        
+
         setattr(fn, '_sa_instrument_role', 'on_link')
         return fn
     on_link = classmethod(on_link)
 
+    def converter(cls, fn):
+        """Tag the method as the collection converter.
+
+        This optional method will be called when a collection is being
+        replaced entirely, as in::
+
+            myobj.acollection = [newvalue1, newvalue2]
+
+        The converter method will receive the object being assigned and should
+        return an iterable of values suitable for use by the ``appender``
+        method.  A converter must not assign values or mutate the collection,
+        it's sole job is to adapt the value the user provides into an iterable
+        of values for the ORM's use.
+
+        The default converter implementation will use duck-typing to do the
+        conversion.  A dict-like collection will be convert into an iterable
+        of dictionary values, and other types will simply be iterated.
+
+            @collection.converter
+            def convert(self, other): ...
+
+        If the duck-typing of the object does not match the type of this
+        collection, a TypeError is raised.
+
+        Supply an implementation of this method if you want to expand the
+        range of possible types that can be assigned in bulk or perform
+        validation on the values about to be assigned.
+        """
+
+        setattr(fn, '_sa_instrument_role', 'converter')
+        return fn
+    converter = classmethod(converter)
+
     def adds(cls, arg):
         """Mark the method as adding an entity to the collection.
 
@@ -340,13 +373,13 @@ class collection(object):
         the method.  The decorator argument indicates which method argument
         holds the SQLAlchemy-relevant value to be added, and return value, if
         any will be considered the value to remove.
-        
+
         Arguments can be specified positionally (i.e. integer) or by name::
 
             @collection.replaces(2)
             def __setitem__(self, index, item): ...
         """
-        
+
         def decorator(fn):
             setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
             setattr(fn, '_sa_instrument_after', 'fire_remove_event')
@@ -374,7 +407,7 @@ class collection(object):
             return fn
         return decorator
     removes = classmethod(removes)
-    
+
     def removes_return(cls):
         """Mark the method as removing an entity in the collection.
 
@@ -417,7 +450,7 @@ def collection_iter(collection):
         raise TypeError("'%s' object is not iterable" %
                         type(collection).__name__)
 
-    
+
 class CollectionAdapter(object):
     """Bridges between the ORM and arbitrary Python collections.
 
@@ -452,6 +485,39 @@ class CollectionAdapter(object):
         if hasattr(data, '_sa_on_link'):
             getattr(data, '_sa_on_link')(None)
 
+    def adapt_like_to_iterable(self, obj):
+        """Converts collection-compatible objects to an iterable of values.
+
+        Can be passed any type of object, and if the underlying collection
+        determines that it can be adapted into a stream of values it can
+        use, returns an iterable of values suitable for append()ing.
+
+        This method may raise TypeError or any other suitable exception
+        if adaptation fails.
+
+        If a converter implementation is not supplied on the collection,
+        a default duck-typing-based implementation is used.
+        """
+
+        converter = getattr(self._data(), '_sa_converter', None)
+        if converter is not None:
+            return converter(obj)
+
+        setting_type = sautil.duck_type_collection(obj)
+
+        if obj is None or setting_type != self.attr.collection_interface:
+            raise TypeError(
+                "Incompatible collection type: %s is not %s-like" %
+                (type(obj).__name__, self.attr.collection_interface.__name__))
+
+        # If the object is an adapted collection, return the (iterable) adapter.
+        if getattr(obj, '_sa_adapter', None) is not None:
+            return getattr(obj, '_sa_adapter')
+        elif setting_type == dict:
+            return getattr(obj, 'itervalues', getattr(obj, 'values'))()
+        else:
+            return iter(obj)
+
     def append_with_event(self, item, initiator=None):
         """Add an entity to the collection, firing mutation events."""
 
@@ -504,7 +570,7 @@ class CollectionAdapter(object):
         mutation, and should be left as None unless you are passing along
         an initiator value from a chained operation.
         """
-        
+
         if initiator is not False and item is not None:
             self.attr.fire_append_event(self.owner_state, item, initiator)
 
@@ -518,7 +584,7 @@ class CollectionAdapter(object):
 
         if initiator is not False and item is not None:
             self.attr.fire_remove_event(self.owner_state, item, initiator)
-    
+
     def __getstate__(self):
         return { 'key': self.attr.key,
                  'owner_state': self.owner_state,
@@ -598,7 +664,7 @@ def _instrument_class(cls):
     # FIXME: more formally document this as a decoratorless/Python 2.3
     # option for specifying instrumentation.  (likely doc'd here in code only,
     # not in online docs.)
-    # 
+    #
     # __instrumentation__ = {
     #   'rolename': 'methodname', # ...
     #   'methods': {
@@ -617,7 +683,7 @@ def _instrument_class(cls):
         raise exceptions.ArgumentError(
             "Can not instrument a built-in type. Use a "
             "subclass, even a trivial one.")
-    
+
     collection_type = sautil.duck_type_collection(cls)
     if collection_type in __interfaces:
         roles = __interfaces[collection_type].copy()
@@ -638,7 +704,8 @@ def _instrument_class(cls):
         # note role declarations
         if hasattr(method, '_sa_instrument_role'):
             role = method._sa_instrument_role
-            assert role in ('appender', 'remover', 'iterator', 'on_link')
+            assert role in ('appender', 'remover', 'iterator',
+                            'on_link', 'converter')
             roles[role] = name
 
         # transfer instrumentation requests from decorated function
@@ -691,7 +758,7 @@ def _instrument_class(cls):
     for method, (before, argument, after) in methods.items():
         setattr(cls, method,
                 _instrument_membership_mutator(getattr(cls, method),
-                                               before, argument, after))    
+                                               before, argument, after))
     # intern the role map
     for role, method in roles.items():
         setattr(cls, '_sa_%s' % role, getattr(cls, method))
@@ -736,7 +803,7 @@ def _instrument_membership_mutator(method, before, argument, after):
             executor = None
         else:
             executor = getattr(args[0], '_sa_adapter', None)
-            
+
         if before and executor:
             getattr(executor, before)(value, initiator)
 
@@ -762,7 +829,7 @@ def __set(collection, item, _sa_initiator=None):
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
             getattr(executor, 'fire_append_event')(item, _sa_initiator)
-                                                  
+
 def __del(collection, item, _sa_initiator=None):
     """Run del events, may eventually be inlined into decorators."""
 
@@ -770,11 +837,11 @@ def __del(collection, item, _sa_initiator=None):
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
             getattr(executor, 'fire_remove_event')(item, _sa_initiator)
-    
+
 def _list_decorators():
     """Hand-turned instrumentation wrappers that can decorate any list-like
     class."""
-    
+
     def _tidy(fn):
         setattr(fn, '_sa_instrumented', True)
         fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__')
@@ -868,7 +935,7 @@ def _list_decorators():
             fn(self, start, end, values)
         _tidy(__setslice__)
         return __setslice__
-    
+
     def __delslice__(fn):
         def __delslice__(self, start, end):
             for value in self[start:end]:
@@ -883,7 +950,7 @@ def _list_decorators():
                 self.append(value)
         _tidy(extend)
         return extend
-    
+
     def pop(fn):
         def pop(self, index=-1):
             item = fn(self, index)
@@ -1101,7 +1168,7 @@ class InstrumentedList(list):
        'remover': 'remove',
        'iterator': '__iter__', }
 
-class InstrumentedSet(sautil.Set): 
+class InstrumentedSet(sautil.Set):
     """An instrumented version of the built-in set (or Set)."""
 
     __instrumentation__ = {
@@ -1109,7 +1176,7 @@ class InstrumentedSet(sautil.Set):
        'remover': 'remove',
        'iterator': '__iter__', }
 
-class InstrumentedDict(dict): 
+class InstrumentedDict(dict):
     """An instrumented version of the built-in dict."""
 
     __instrumentation__ = {
@@ -1146,7 +1213,7 @@ class MappedCollection(dict):
     callable that takes an object and returns an object for use as a dictionary
     key.
     """
-    
+
     def __init__(self, keyfunc):
         """Create a new collection with keying provided by keyfunc.
 
@@ -1169,10 +1236,10 @@ class MappedCollection(dict):
         self.__setitem__(key, value, _sa_initiator)
     set = collection.internally_instrumented(set)
     set = collection.appender(set)
-    
+
     def remove(self, value, _sa_initiator=None):
         """Remove an item from the collection by value, consulting this instance's keyfunc for the key."""
-        
+
         key = self.keyfunc(value)
         # Let self[key] raise if key is not in this collection
         if self[key] != value:
@@ -1185,3 +1252,26 @@ class MappedCollection(dict):
         self.__delitem__(key, _sa_initiator)
     remove = collection.internally_instrumented(remove)
     remove = collection.remover(remove)
+
+    def _convert(self, dictlike):
+        """Validate and convert a dict-like object into values for set()ing.
+
+        This is called behind the scenes when a MappedCollection is replaced
+        entirely by another collection, as in::
+
+          myobj.mappedcollection = {'a':obj1, 'b': obj2} # ...
+
+        Raises a TypeError if the key in any (key, value) pair in the dictlike
+        object does not match the key that this collection's keyfunc would
+        have assigned for that value.
+        """
+
+        for incoming_key, value in sautil.dictlike_iteritems(dictlike):
+            new_key = self.keyfunc(value)
+            if incoming_key != new_key:
+                raise TypeError(
+                    "Found incompatible key %r for value %r; this collection's "
+                    "keying function requires a key of %r for this value." % (
+                    incoming_key, value, new_key))
+            yield value
+    _convert = collection.converter(_convert)
index 1fee4cef0a5c101c3045f26daac9cfaf888bf0f2..705168d2095b04c24b53f9a5ca519103f62599dc 100644 (file)
@@ -193,6 +193,30 @@ def duck_type_collection(specimen, default=None):
     else:
         return default
 
+def dictlike_iteritems(dictlike):
+    """Return a (key, value) iterator for almost any dict-like object."""
+
+    if hasattr(dictlike, 'iteritems'):
+        return dictlike.iteritems()
+    elif hasattr(dictlike, 'items'):
+        return iter(dictlike.items())
+
+    getter = getattr(dictlike, '__getitem__', getattr(dictlike, 'get', None))
+    if getter is None:
+        raise TypeError(
+            "Object '%r' is not dict-like" % dictlike)
+
+    if hasattr(dictlike, 'iterkeys'):
+        def iterator():
+            for key in dictlike.iterkeys():
+                yield key, getter(key)
+        return iterator()
+    elif hasattr(dictlike, 'keys'):
+        return iter([(key, getter(key)) for key in dictlike.keys()])
+    else:
+        raise TypeError(
+            "Object '%r' is not dict-like" % dictlike)
+
 def assert_arg_type(arg, argtype, name):
     if isinstance(arg, argtype):
         return arg
index 1cfcd8fb5aba4bb65141b6d04d4b5ff5251f6547..932ad876a21b906809bb113c4094e9c605eab4c0 100644 (file)
@@ -14,7 +14,7 @@ class OrderedDictTest(PersistTest):
 
         self.assert_(o.keys() == ['a', 'b', 'snack', 'c'])
         self.assert_(o.values() == [1, 2, 'attack', 3])
-    
+
         o.pop('snack')
 
         self.assert_(o.keys() == ['a', 'b', 'c'])
@@ -49,7 +49,7 @@ class ColumnCollectionTest(PersistTest):
             assert False
         except exceptions.ArgumentError, e:
             assert str(e) == "__contains__ requires a string argument"
-            
+
     def test_compare(self):
         cc1 = sql.ColumnCollection()
         cc2 = sql.ColumnCollection()
@@ -66,24 +66,25 @@ class ColumnCollectionTest(PersistTest):
 class ArgSingletonTest(unittest.TestCase):
     def test_cleanout(self):
         util.ArgSingleton.instances.clear()
-        
+
         class MyClass(object):
             __metaclass__ = util.ArgSingleton
             def __init__(self, x, y):
                 self.x = x
                 self.y = y
-        
+
         m1 = MyClass(3, 4)
         m2 = MyClass(1, 5)
         m3 = MyClass(3, 4)
         assert m1 is m3
         assert m2 is not m3
         assert len(util.ArgSingleton.instances) == 2
-        
+
         m1 = m2 = m3 = None
         MyClass.dispose(MyClass)
         assert len(util.ArgSingleton.instances) == 0
 
+
 class ImmutableSubclass(str):
     pass
 
@@ -222,5 +223,76 @@ class IdentitySetTest(unittest.TestCase):
         self.assertRaises(TypeError, hash, ids)
 
 
+class DictlikeIteritemsTest(unittest.TestCase):
+    baseline = set([('a', 1), ('b', 2), ('c', 3)])
+
+    def _ok(self, instance):
+        iterator = util.dictlike_iteritems(instance)
+        self.assertEquals(set(iterator), self.baseline)
+
+    def _notok(self, instance):
+        self.assertRaises(TypeError,
+                          util.dictlike_iteritems,
+                          instance)
+
+    def test_dict(self):
+        d = dict(a=1,b=2,c=3)
+        self._ok(d)
+
+    def test_subdict(self):
+        class subdict(dict):
+            pass
+        d = subdict(a=1,b=2,c=3)
+        self._ok(d)
+
+    def test_UserDict(self):
+        import UserDict
+        d = UserDict.UserDict(a=1,b=2,c=3)
+        self._ok(d)
+
+    def test_object(self):
+        self._notok(object())
+
+    def test_duck_1(self):
+        class duck1(object):
+            def iteritems(duck):
+                return iter(self.baseline)
+        self._ok(duck1())
+
+    def test_duck_2(self):
+        class duck2(object):
+            def items(duck):
+                return list(self.baseline)
+        self._ok(duck2())
+
+    def test_duck_3(self):
+        class duck3(object):
+            def iterkeys(duck):
+                return iter(['a', 'b', 'c'])
+            def __getitem__(duck, key):
+                return dict(a=1,b=2,c=3).get(key)
+        self._ok(duck3())
+
+    def test_duck_4(self):
+        class duck4(object):
+            def iterkeys(duck):
+                return iter(['a', 'b', 'c'])
+        self._notok(duck4())
+
+    def test_duck_5(self):
+        class duck5(object):
+            def keys(duck):
+                return ['a', 'b', 'c']
+            def get(duck, key):
+                return dict(a=1,b=2,c=3).get(key)
+        self._ok(duck5())
+
+    def test_duck_6(self):
+        class duck6(object):
+            def keys(duck):
+                return ['a', 'b', 'c']
+        self._notok(duck6())
+
+
 if __name__ == "__main__":
     testbase.main()
index 2accd2fc83f2a71f7a0aa0273c52f13bab6591ef..fe8b40255c8d59b69ca2ad3fbcb710e078626a34 100644 (file)
@@ -288,13 +288,13 @@ class CustomDictTest(DictTest):
         try:
             p1._children = []
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(True)
 
         try:
             p1._children = None
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(True)
 
         self.assertRaises(TypeError, set, [p1.children])
@@ -404,13 +404,13 @@ class SetTest(_CollectionOperations):
         try:
             p1._children = []
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(True)
 
         try:
             p1._children = None
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(True)
 
         self.assertRaises(TypeError, set, [p1.children])
index 5d1753909af804e4e8b0b2b6b945384aa1c4dc7c..43b2f41e25314d79deef4faf79954bcdefe79912 100644 (file)
@@ -269,7 +269,7 @@ class CollectionsTest(PersistTest):
         try:
             obj.attr = set([e4])
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(e4 not in canary.data)
             self.assert_(e3 in canary.data)
 
@@ -526,7 +526,7 @@ class CollectionsTest(PersistTest):
         try:
             obj.attr = [e4]
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(e4 not in canary.data)
             self.assert_(e3 in canary.data)
 
@@ -737,23 +737,42 @@ class CollectionsTest(PersistTest):
         self.assert_(e1 in canary.removed)
         self.assert_(e2 in canary.added)
 
+
+        # key validity on bulk assignment is a basic feature of MappedCollection
+        # but is not present in basic, @converter-less dict collections.
         e3 = creator()
-        real_dict = dict(keyignored1=e3)
-        obj.attr = real_dict
-        self.assert_(obj.attr is not real_dict)
-        self.assert_('keyignored1' not in obj.attr)
-        self.assert_(set(collections.collection_adapter(obj.attr)) == set([e3]))
-        self.assert_(e2 in canary.removed)
-        self.assert_(e3 in canary.added)
+        if isinstance(obj.attr, collections.MappedCollection):
+            real_dict = dict(badkey=e3)
+            try:
+                obj.attr = real_dict
+                self.assert_(False)
+            except TypeError:
+                pass
+            self.assert_(obj.attr is not real_dict)
+            self.assert_('badkey' not in obj.attr)
+            self.assertEquals(set(collections.collection_adapter(obj.attr)),
+                              set([e2]))
+            self.assert_(e3 not in canary.added)
+        else:
+            real_dict = dict(keyignored1=e3)
+            obj.attr = real_dict
+            self.assert_(obj.attr is not real_dict)
+            self.assert_('keyignored1' not in obj.attr)
+            self.assertEquals(set(collections.collection_adapter(obj.attr)),
+                              set([e3]))
+            self.assert_(e2 in canary.removed)
+            self.assert_(e3 in canary.added)
+
+        obj.attr = typecallable()
+        self.assertEquals(list(collections.collection_adapter(obj.attr)), [])
 
         e4 = creator()
         try:
             obj.attr = [e4]
             self.assert_(False)
-        except exceptions.ArgumentError:
+        except TypeError:
             self.assert_(e4 not in canary.data)
-            self.assert_(e3 in canary.data)
-        
+
     def test_dict(self):
         try:
             self._test_adapter(dict, dictable_entity,