]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- composite now relates to its parent class in terms of MapperProperty,
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 18:25:03 +0000 (13:25 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 15 Jan 2011 18:25:03 +0000 (13:25 -0500)
not Column.  This allows it to compose any mapped attributes, including
relationship().  [ticket:2024]

lib/sqlalchemy/orm/descriptor_props.py
test/orm/test_composites.py

index e6166aa9ecc14282ce6c862c8542446f21e0c018..79c57ac0e4159126ad69a567c4167c6ef17da015 100644 (file)
@@ -12,8 +12,9 @@ as actively in the load/persist ORM loop.
 
 from sqlalchemy.orm.interfaces import \
     MapperProperty, PropComparator, StrategizedProperty
+from sqlalchemy.orm.mapper import _none_set
 from sqlalchemy.orm import attributes
-from sqlalchemy import util, sql, exc as sa_exc, event
+from sqlalchemy import util, sql, exc as sa_exc, event, schema
 from sqlalchemy.sql import expression
 properties = util.importlater('sqlalchemy.orm', 'properties')
 
@@ -72,8 +73,8 @@ class DescriptorProperty(MapperProperty):
 
 class CompositeProperty(DescriptorProperty):
 
-    def __init__(self, class_, *columns, **kwargs):
-        self.columns = columns
+    def __init__(self, class_, *attrs, **kwargs):
+        self.attrs = attrs
         self.composite_class = class_
         self.active_history = kwargs.get('active_history', False)
         self.deferred = kwargs.get('deferred', False)
@@ -90,6 +91,7 @@ class CompositeProperty(DescriptorProperty):
         has been associated with its parent mapper.
 
         """
+        self._init_props()
         self._setup_arguments_on_columns()
 
     def _create_descriptor(self):
@@ -101,12 +103,17 @@ class CompositeProperty(DescriptorProperty):
         def fget(instance):
             dict_ = attributes.instance_dict(instance)
 
-            # key not present, assume the columns aren't
-            # loaded.  The load events will establish
-            # the item.
             if self.key not in dict_:
-                for key in self._attribute_keys:
-                    getattr(instance, key)
+                # key not present.  Iterate through related
+                # attributes, retrieve their values.  This
+                # ensures they all load.
+                values = [getattr(instance, key) for key in self._attribute_keys]
+
+                # usually, the load() event will have loaded our key
+                # at this point, unless we only loaded relationship()
+                # attributes above.  Populate here if that's the case.
+                if self.key not in dict_ and not _none_set.issuperset(values):
+                    dict_[self.key] = self.composite_class(*values)
 
             return dict_.get(self.key, None)
 
@@ -138,13 +145,30 @@ class CompositeProperty(DescriptorProperty):
 
         self.descriptor = property(fget, fset, fdel)
 
+    @util.memoized_property
+    def _comparable_elements(self):
+        return [
+            getattr(self.parent.class_, prop.key)
+            for prop in self.props
+        ]
+
+    def _init_props(self):
+        self.props = props = []
+        for attr in self.attrs:
+            if isinstance(attr, basestring):
+                prop = self.parent.get_property(attr)
+            elif isinstance(attr, schema.Column):
+                prop = self.parent._columntoproperty[attr]
+            elif isinstance(attr, attributes.InstrumentedAttribute):
+                prop = attr.property
+            props.append(prop)
+
     def _setup_arguments_on_columns(self):
         """Propagate configuration arguments made on this composite
         to the target columns, for those that apply.
 
         """
-        for col in self.columns:
-            prop = self.parent._columntoproperty[col]
+        for prop in self.props:
             prop.active_history = self.active_history
             if self.deferred:
                 prop.deferred = self.deferred
@@ -195,8 +219,7 @@ class CompositeProperty(DescriptorProperty):
     @util.memoized_property
     def _attribute_keys(self):
         return [
-            self.parent._columntoproperty[col].key
-            for col in self.columns
+            prop.key for prop in self.props
         ]
 
     def get_history(self, state, dict_, **kw):
@@ -206,8 +229,8 @@ class CompositeProperty(DescriptorProperty):
         deleted = []
 
         has_history = False
-        for col in self.columns:
-            key = self.parent._columntoproperty[col].key
+        for prop in self.props:
+            key = prop.key
             hist = state.manager[key].impl.get_history(state, dict_)
             if hist.has_changes():
                 has_history = True
@@ -241,19 +264,19 @@ class CompositeProperty(DescriptorProperty):
             if self.adapter:
                 # TODO: test coverage for adapted composite comparison
                 return expression.ClauseList(
-                            *[self.adapter(x) for x in self.prop.columns])
+                            *[self.adapter(x) for x in self.prop._comparable_elements])
             else:
-                return expression.ClauseList(*self.prop.columns)
+                return expression.ClauseList(*self.prop._comparable_elements)
 
         __hash__ = None
 
         def __eq__(self, other):
             if other is None:
-                values = [None] * len(self.prop.columns)
+                values = [None] * len(self.prop._comparable_elements)
             else:
                 values = other.__composite_values__()
             return sql.and_(
-                    *[a==b for a, b in zip(self.prop.columns, values)])
+                    *[a==b for a, b in zip(self.prop._comparable_elements, values)])
 
         def __ne__(self, other):
             return sql.not_(self.__eq__(other))
index 46d6bb44d921cd18aa336599e275fc69a8ca4aa0..5ebf46e0b9a076d9d4b024e495c74d67217760df 100644 (file)
@@ -426,3 +426,137 @@ class MappedSelectTest(_base.MappedTest):
             testing.db.execute(values.select()).fetchall(),
             [(1, 1, u'Red', u'5'), (2, 1, u'Blue', u'1')]
         )
+
+class ManyToOneTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('a', 
+            metadata,
+            Column('id', Integer, primary_key=True, 
+                            test_needs_autoincrement=True),
+            Column('b1', String(20)),
+            Column('b2_id', Integer, ForeignKey('b.id'))
+        )
+
+        Table('b', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(20))
+        )
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        class A(_base.ComparableEntity):
+            pass
+        class B(_base.ComparableEntity):
+            pass
+
+        class C(_base.BasicEntity):
+            def __init__(self, b1, b2):
+                self.b1, self.b2 = b1, b2
+
+            def __composite_values__(self):
+                return self.b1, self.b2
+
+            def __eq__(self, other):
+                return isinstance(other, C) and \
+                    other.b1 == self.b1 and \
+                    other.b2 == self.b2
+
+
+        mapper(A, a, properties={
+            'b2':relationship(B),
+            'c':composite(C, 'b1', 'b2')
+        })
+        mapper(B, b)
+
+    @testing.resolve_artifact_names
+    def test_persist(self):
+        sess = Session()
+        sess.add(A(c=C('b1', B(data='b2'))))
+        sess.commit()
+
+        a1 = sess.query(A).one()
+        eq_(a1.c, C('b1', B(data='b2')))
+
+    @testing.resolve_artifact_names
+    def test_query(self):
+        sess = Session()
+        b1, b2 = B(data='b1'), B(data='b2')
+        a1 = A(c=C('a1b1', b1))
+        a2 = A(c=C('a2b1', b2))
+        sess.add_all([a1, a2])
+        sess.commit()
+
+        eq_(
+            sess.query(A).filter(A.c==C('a2b1', b2)).one(),
+            a2
+        )
+
+class ConfigurationTest(_base.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('edge', metadata,
+            Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+            Column('x1', Integer),
+            Column('y1', Integer),
+            Column('x2', Integer),
+            Column('y2', Integer),
+        )
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        class Point(_base.BasicEntity):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+            def __composite_values__(self):
+                return [self.x, self.y]
+            def __eq__(self, other):
+                return isinstance(other, Point) and \
+                        other.x == self.x and \
+                        other.y == self.y
+            def __ne__(self, other):
+                return not isinstance(other, Point) or \
+                    not self.__eq__(other)
+
+        class Edge(_base.ComparableEntity):
+            pass
+
+    @testing.resolve_artifact_names
+    def _test_roundtrip(self):
+        e1 = Edge(start=Point(3, 4), end=Point(5, 6))
+        sess = Session()
+        sess.add(e1)
+        sess.commit()
+
+        eq_(
+            sess.query(Edge).one(),
+            Edge(start=Point(3, 4), end=Point(5, 6))
+        )
+
+    @testing.resolve_artifact_names
+    def test_columns(self):
+        mapper(Edge, edge, properties={
+            'start':sa.orm.composite(Point, edge.c.x1, edge.c.y1),
+            'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+        })
+
+        self._test_roundtrip()
+
+    @testing.resolve_artifact_names
+    def test_attributes(self):
+        m = mapper(Edge, edge)
+        m.add_property('start', sa.orm.composite(Point, Edge.x1, Edge.y1))
+        m.add_property('end', sa.orm.composite(Point, Edge.x2, Edge.y2))
+
+        self._test_roundtrip()
+
+    @testing.resolve_artifact_names
+    def test_strings(self):
+        m = mapper(Edge, edge)
+        m.add_property('start', sa.orm.composite(Point, 'x1', 'y1'))
+        m.add_property('end', sa.orm.composite(Point, 'x2', 'y2'))
+
+        self._test_roundtrip()
\ No newline at end of file