]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- crudely, this replaces CompositeProperty's base to be
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 18:47:48 +0000 (13:47 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Dec 2010 18:47:48 +0000 (13:47 -0500)
DescriptorProperty.   We have to lose mutability (yikes
composites were using mutable too !).   Also the getter
is not particularly efficient since it recreates the composite
every time, probably want to stick it in __dict__.  also
rewrite the unit tests

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/test_mapper.py

index d1b725c8e289b5e75b316fc2fdf1eaf8250515c2..f16376ce60f1e6a17e1999bfbb1c3f01606bcc65 100644 (file)
@@ -162,80 +162,6 @@ class ColumnProperty(StrategizedProperty):
 
 log.class_logger(ColumnProperty)
 
-class CompositeProperty(ColumnProperty):
-    """subclasses ColumnProperty to provide composite type support."""
-    
-    def __init__(self, class_, *columns, **kwargs):
-        super(CompositeProperty, self).__init__(*columns, **kwargs)
-        self._col_position_map = util.column_dict(
-                                            (c, i) for i, c 
-                                            in enumerate(columns))
-        self.composite_class = class_
-        self.strategy_class = strategies.CompositeColumnLoader
-
-    def copy(self):
-        return CompositeProperty(
-                        deferred=self.deferred, 
-                        group=self.group,
-                        composite_class=self.composite_class, 
-                        active_history=self.active_history,
-                        *self.columns)
-
-    def do_init(self):
-        # skip over ColumnProperty's do_init(),
-        # which issues assertions that do not apply to CompositeColumnProperty
-        super(ColumnProperty, self).do_init()
-
-    def _getcommitted(self, state, dict_, column, passive=False):
-        # TODO: no coverage here
-        obj = state.get_impl(self.key).\
-                        get_committed_value(state, dict_, passive=passive)
-        return self.get_col_value(column, obj)
-    
-    def set_col_value(self, state, dict_, value, column):
-        obj = state.get_impl(self.key).get(state, dict_)
-        if obj is None:
-            obj = self.composite_class(*[None for c in self.columns])
-            state.get_impl(self.key).set(state, state.dict, obj, None)
-
-        if hasattr(obj, '__set_composite_values__'):
-            values = list(obj.__composite_values__())
-            values[self._col_position_map[column]] = value
-            obj.__set_composite_values__(*values)
-        else:
-            setattr(obj, column.key, value)
-            
-    def get_col_value(self, column, value):
-        if value is None:
-            return None
-        for a, b in zip(self.columns, value.__composite_values__()):
-            if a is column:
-                return b
-
-    class Comparator(PropComparator):
-        def __clause_element__(self):
-            if self.adapter:
-                # TODO: test coverage for adapted composite comparison
-                return expression.ClauseList(
-                            *[self.adapter(x) for x in self.prop.columns])
-            else:
-                return expression.ClauseList(*self.prop.columns)
-        
-        __hash__ = None
-        
-        def __eq__(self, other):
-            if other is None:
-                values = [None] * len(self.prop.columns)
-            else:
-                values = other.__composite_values__()
-            return sql.and_(
-                    *[a==b for a, b in zip(self.prop.columns, values)])
-            
-        def __ne__(self, other):
-            return sql.not_(self.__eq__(other))
-
-    def __str__(self):
-        return str(self.parent.class_.__name__) + "." + self.key
 
 class DescriptorProperty(MapperProperty):
     """:class:`MapperProperty` which proxies access to a 
@@ -243,7 +169,9 @@ class DescriptorProperty(MapperProperty):
 
     def instrument_class(self, mapper):
         from sqlalchemy.ext import hybrid
-
+        
+        prop = self
+        
         # hackety hack hack
         class _ProxyImpl(object):
             accepts_scalar_loader = False
@@ -251,7 +179,11 @@ class DescriptorProperty(MapperProperty):
 
             def __init__(self, key):
                 self.key = key
-
+            
+            if hasattr(prop, 'get_history'):
+                def get_history(self, state, dict_, **kw):
+                    return prop.get_history(state, dict_, **kw)
+                
         if self.descriptor is None:
             desc = getattr(mapper.class_, self.key, None)
             if mapper._is_userland_descriptor(desc):
@@ -296,7 +228,7 @@ class DescriptorProperty(MapperProperty):
         descriptor.expr = get_comparator
         descriptor.impl = _ProxyImpl(self.key)
         mapper.class_manager.instrument_attribute(self.key, descriptor)
-
+    
     def setup(self, context, entity, path, adapter, **kwargs):
         pass
 
@@ -307,6 +239,105 @@ class DescriptorProperty(MapperProperty):
                 dest_state, dest_dict, load, _recursive):
         pass
 
+class CompositeProperty(DescriptorProperty):
+    
+    def __init__(self, class_, *columns, **kwargs):
+        self.columns = columns
+        self.composite_class = class_
+        self.active_history = kwargs.get('active_history', False)
+        self.deferred = kwargs.get('deferred', False)
+        self.group = kwargs.get('group', None)
+        
+        prop = self
+        def fget(instance):
+            return prop.composite_class(
+                *[getattr(instance, prop.parent._columntoproperty[col].key) 
+                for col in prop.columns]
+            )
+        def fset(instance, value):
+            if value is None:
+                fdel(instance)
+            else:
+                for col, value in zip(prop.columns, value.__composite_values__()):
+                    setattr(instance, prop.parent._columntoproperty[col].key, value)
+        
+        def fdel(instance):
+            for col in prop.columns:
+                setattr(instance, prop.parent._columntoproperty[col].key, None)
+        self.descriptor = property(fget, fset, fdel)
+        
+    def get_history(self, state, dict_, **kw):
+        """Provided for userland code that uses attributes.get_history()."""
+        
+        added = []
+        deleted = []
+        
+        has_history = False
+        for col in self.columns:
+            key = self.parent._columntoproperty[col].key
+            hist = state.manager[key].impl.get_history(state, dict_)
+            if hist.has_changes():
+                has_history = True
+            
+            added.extend(hist.non_deleted())
+            if hist.deleted:
+                deleted.extend(hist.deleted)
+            else:
+                deleted.append(None)
+        
+        if has_history:
+            return attributes.History(
+                [self.composite_class(*added)],
+                (),
+                [self.composite_class(*deleted)]
+            )
+        else:
+            return attributes.History(
+                (),[self.composite_class(*added)], ()
+            )
+
+    def do_init(self):
+        for col in self.columns:
+            prop = self.parent._columntoproperty[col]
+            prop.active_history = self.active_history
+            if self.deferred:
+                prop.deferred = self.deferred
+                prop.strategy_class = strategies.DeferredColumnLoader
+            prop.group = self.group
+        # strategies ...
+
+    def _comparator_factory(self, mapper):
+        return CompositeProperty.Comparator(self)
+
+    class Comparator(PropComparator):
+        def __init__(self, prop, adapter=None):
+            self.prop = prop
+            self.adapter = adapter
+            
+        def __clause_element__(self):
+            if self.adapter:
+                # TODO: test coverage for adapted composite comparison
+                return expression.ClauseList(
+                            *[self.adapter(x) for x in self.prop.columns])
+            else:
+                return expression.ClauseList(*self.prop.columns)
+        
+        __hash__ = None
+        
+        def __eq__(self, other):
+            if other is None:
+                values = [None] * len(self.prop.columns)
+            else:
+                values = other.__composite_values__()
+            return sql.and_(
+                    *[a==b for a, b in zip(self.prop.columns, values)])
+            
+        def __ne__(self, other):
+            return sql.not_(self.__eq__(other))
+
+    def __str__(self):
+        return str(self.parent.class_.__name__) + "." + self.key
+
 class ConcreteInheritedProperty(DescriptorProperty):
     """A 'do nothing' :class:`MapperProperty` that disables 
     an attribute on a concrete subclass that is only present
index c4619d3a72e7022cc5b70877183cf8af02367d93..21f22ef5092c855d147818c642f8636c3318e7e2 100644 (file)
@@ -138,59 +138,6 @@ class ColumnLoader(LoaderStrategy):
 
 log.class_logger(ColumnLoader)
 
-class CompositeColumnLoader(ColumnLoader):
-    """Strategize the loading of a composite column-based MapperProperty."""
-
-    def init_class_attribute(self, mapper):
-        self.is_class_level = True
-        self.logger.info("%s register managed composite attribute", self)
-
-        def copy(obj):
-            if obj is None:
-                return None
-            return self.parent_property.\
-                        composite_class(*obj.__composite_values__())
-            
-        def compare(a, b):
-            if a is None or b is None:
-                return a is b
-                
-            for col, aprop, bprop in zip(self.columns,
-                                         a.__composite_values__(),
-                                         b.__composite_values__()):
-                if not col.type.compare_values(aprop, bprop):
-                    return False
-            else:
-                return True
-
-        _register_attribute(self, mapper, useobject=False,
-            compare_function=compare,
-            copy_function=copy,
-            mutable_scalars=True,
-            active_history=self.parent_property.active_history,
-        )
-
-    def create_row_processor(self, selectcontext, path, mapper, 
-                                                    row, adapter):
-        key = self.key
-        columns = self.columns
-        composite_class = self.parent_property.composite_class
-        if adapter:
-            columns = [adapter.columns[c] for c in columns]
-            
-        for c in columns:
-            if c not in row:
-                def new_execute(state, dict_, row):
-                    state.expire_attribute_pre_commit(dict_, key)
-                break
-        else:
-            def new_execute(state, dict_, row):
-                dict_[key] = composite_class(*[row[c] for c in columns])
-
-        return new_execute, None, None
-
-log.class_logger(CompositeColumnLoader)
-    
 class DeferredColumnLoader(LoaderStrategy):
     """Strategize the loading of a deferred column-based MapperProperty."""
 
index c94ef9b3fd0026862b9b3258fef7413e04d61095..621e5f47ced71cd5d0e8cb3bc73f0ac737f483b0 100644 (file)
@@ -11,7 +11,7 @@ from sqlalchemy.orm import mapper, relationship, backref, \
     validates, aliased, Mapper
 from sqlalchemy.orm import defer, deferred, synonym, attributes, \
     column_property, composite, relationship, dynamic_loader, \
-    comparable_property, AttributeExtension
+    comparable_property, AttributeExtension, Session
 from sqlalchemy.orm.instrumentation import ClassManager
 from test.lib.testing import eq_, AssertsCompiledSQL
 from test.orm import _base, _fixtures
@@ -2213,39 +2213,34 @@ class CompositeTypesTest(_base.MappedTest):
             'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2)
         })
 
-        sess = create_session()
+        sess = Session()
         g = Graph()
         g.id = 1
         g.version_id=1
         g.edges.append(Edge(Point(3, 4), Point(5, 6)))
         g.edges.append(Edge(Point(14, 5), Point(2, 7)))
         sess.add(g)
-        sess.flush()
+        sess.commit()
 
-        sess.expunge_all()
         g2 = sess.query(Graph).get([g.id, g.version_id])
         for e1, e2 in zip(g.edges, g2.edges):
             eq_(e1.start, e2.start)
             eq_(e1.end, e2.end)
 
         g2.edges[1].end = Point(18, 4)
-        sess.flush()
-        sess.expunge_all()
+        sess.commit()
+
         e = sess.query(Edge).get(g2.edges[1].id)
         eq_(e.end, Point(18, 4))
-
-        e.end.x = 19
-        e.end.y = 5
-        sess.flush()
-        sess.expunge_all()
-        eq_(sess.query(Edge).get(g2.edges[1].id).end, Point(19, 5))
-
-        g.edges[1].end = Point(19, 5)
-
+        
+        e.end = Point(19, 5)
+        sess.commit()
+        g.id, g.version_id, g.edges
         sess.expunge_all()
+        
         def go():
-            g2 = (sess.query(Graph).
-                  options(sa.orm.joinedload('edges'))).get([g.id, g.version_id])
+            g2 = sess.query(Graph).\
+                  options(sa.orm.joinedload('edges')).get([g.id, g.version_id])
             for e1, e2 in zip(g.edges, g2.edges):
                 eq_(e1.start, e2.start)
                 eq_(e1.end, e2.end)
@@ -2261,9 +2256,9 @@ class CompositeTypesTest(_base.MappedTest):
 
         # query by columns
         eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)])
-
+        
         e = g.edges[1]
-        e.end.x = e.end.y = None
+        del e.end
         sess.flush()
         eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)])