from sqlalchemy.orm.interfaces import \
MapperProperty, PropComparator, StrategizedProperty
+from sqlalchemy.orm.mapper import _none_set
from sqlalchemy.orm import attributes
-from sqlalchemy import util, sql, exc as sa_exc, event
+from sqlalchemy import util, sql, exc as sa_exc, event, schema
from sqlalchemy.sql import expression
properties = util.importlater('sqlalchemy.orm', 'properties')
class CompositeProperty(DescriptorProperty):
- def __init__(self, class_, *columns, **kwargs):
- self.columns = columns
+ def __init__(self, class_, *attrs, **kwargs):
+ self.attrs = attrs
self.composite_class = class_
self.active_history = kwargs.get('active_history', False)
self.deferred = kwargs.get('deferred', False)
has been associated with its parent mapper.
"""
+ self._init_props()
self._setup_arguments_on_columns()
def _create_descriptor(self):
def fget(instance):
dict_ = attributes.instance_dict(instance)
- # key not present, assume the columns aren't
- # loaded. The load events will establish
- # the item.
if self.key not in dict_:
- for key in self._attribute_keys:
- getattr(instance, key)
+ # key not present. Iterate through related
+ # attributes, retrieve their values. This
+ # ensures they all load.
+ values = [getattr(instance, key) for key in self._attribute_keys]
+
+ # usually, the load() event will have loaded our key
+ # at this point, unless we only loaded relationship()
+ # attributes above. Populate here if that's the case.
+ if self.key not in dict_ and not _none_set.issuperset(values):
+ dict_[self.key] = self.composite_class(*values)
return dict_.get(self.key, None)
self.descriptor = property(fget, fset, fdel)
+ @util.memoized_property
+ def _comparable_elements(self):
+ return [
+ getattr(self.parent.class_, prop.key)
+ for prop in self.props
+ ]
+
+ def _init_props(self):
+ self.props = props = []
+ for attr in self.attrs:
+ if isinstance(attr, basestring):
+ prop = self.parent.get_property(attr)
+ elif isinstance(attr, schema.Column):
+ prop = self.parent._columntoproperty[attr]
+ elif isinstance(attr, attributes.InstrumentedAttribute):
+ prop = attr.property
+ props.append(prop)
+
def _setup_arguments_on_columns(self):
"""Propagate configuration arguments made on this composite
to the target columns, for those that apply.
"""
- for col in self.columns:
- prop = self.parent._columntoproperty[col]
+ for prop in self.props:
prop.active_history = self.active_history
if self.deferred:
prop.deferred = self.deferred
@util.memoized_property
def _attribute_keys(self):
return [
- self.parent._columntoproperty[col].key
- for col in self.columns
+ prop.key for prop in self.props
]
def get_history(self, state, dict_, **kw):
deleted = []
has_history = False
- for col in self.columns:
- key = self.parent._columntoproperty[col].key
+ for prop in self.props:
+ key = prop.key
hist = state.manager[key].impl.get_history(state, dict_)
if hist.has_changes():
has_history = True
if self.adapter:
# TODO: test coverage for adapted composite comparison
return expression.ClauseList(
- *[self.adapter(x) for x in self.prop.columns])
+ *[self.adapter(x) for x in self.prop._comparable_elements])
else:
- return expression.ClauseList(*self.prop.columns)
+ return expression.ClauseList(*self.prop._comparable_elements)
__hash__ = None
def __eq__(self, other):
if other is None:
- values = [None] * len(self.prop.columns)
+ values = [None] * len(self.prop._comparable_elements)
else:
values = other.__composite_values__()
return sql.and_(
- *[a==b for a, b in zip(self.prop.columns, values)])
+ *[a==b for a, b in zip(self.prop._comparable_elements, values)])
def __ne__(self, other):
return sql.not_(self.__eq__(other))
testing.db.execute(values.select()).fetchall(),
[(1, 1, u'Red', u'5'), (2, 1, u'Blue', u'1')]
)
+
+class ManyToOneTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('a',
+ metadata,
+ Column('id', Integer, primary_key=True,
+ test_needs_autoincrement=True),
+ Column('b1', String(20)),
+ Column('b2_id', Integer, ForeignKey('b.id'))
+ )
+
+ Table('b', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(20))
+ )
+
+ @classmethod
+ @testing.resolve_artifact_names
+ def setup_mappers(cls):
+ class A(_base.ComparableEntity):
+ pass
+ class B(_base.ComparableEntity):
+ pass
+
+ class C(_base.BasicEntity):
+ def __init__(self, b1, b2):
+ self.b1, self.b2 = b1, b2
+
+ def __composite_values__(self):
+ return self.b1, self.b2
+
+ def __eq__(self, other):
+ return isinstance(other, C) and \
+ other.b1 == self.b1 and \
+ other.b2 == self.b2
+
+
+ mapper(A, a, properties={
+ 'b2':relationship(B),
+ 'c':composite(C, 'b1', 'b2')
+ })
+ mapper(B, b)
+
+ @testing.resolve_artifact_names
+ def test_persist(self):
+ sess = Session()
+ sess.add(A(c=C('b1', B(data='b2'))))
+ sess.commit()
+
+ a1 = sess.query(A).one()
+ eq_(a1.c, C('b1', B(data='b2')))
+
+ @testing.resolve_artifact_names
+ def test_query(self):
+ sess = Session()
+ b1, b2 = B(data='b1'), B(data='b2')
+ a1 = A(c=C('a1b1', b1))
+ a2 = A(c=C('a2b1', b2))
+ sess.add_all([a1, a2])
+ sess.commit()
+
+ eq_(
+ sess.query(A).filter(A.c==C('a2b1', b2)).one(),
+ a2
+ )
+
+class ConfigurationTest(_base.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table('edge', metadata,
+ Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
+ Column('x1', Integer),
+ Column('y1', Integer),
+ Column('x2', Integer),
+ Column('y2', Integer),
+ )
+
+ @classmethod
+ @testing.resolve_artifact_names
+ def setup_mappers(cls):
+ class Point(_base.BasicEntity):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+ def __composite_values__(self):
+ return [self.x, self.y]
+ def __eq__(self, other):
+ return isinstance(other, Point) and \
+ other.x == self.x and \
+ other.y == self.y
+ def __ne__(self, other):
+ return not isinstance(other, Point) or \
+ not self.__eq__(other)
+
+ class Edge(_base.ComparableEntity):
+ pass
+
+ @testing.resolve_artifact_names
+ def _test_roundtrip(self):
+ e1 = Edge(start=Point(3, 4), end=Point(5, 6))
+ sess = Session()
+ sess.add(e1)
+ sess.commit()
+
+ eq_(
+ sess.query(Edge).one(),
+ Edge(start=Point(3, 4), end=Point(5, 6))
+ )
+
+ @testing.resolve_artifact_names
+ def test_columns(self):
+ mapper(Edge, edge, properties={
+ 'start':sa.orm.composite(Point, edge.c.x1, edge.c.y1),
+ 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2)
+ })
+
+ self._test_roundtrip()
+
+ @testing.resolve_artifact_names
+ def test_attributes(self):
+ m = mapper(Edge, edge)
+ m.add_property('start', sa.orm.composite(Point, Edge.x1, Edge.y1))
+ m.add_property('end', sa.orm.composite(Point, Edge.x2, Edge.y2))
+
+ self._test_roundtrip()
+
+ @testing.resolve_artifact_names
+ def test_strings(self):
+ m = mapper(Edge, edge)
+ m.add_property('start', sa.orm.composite(Point, 'x1', 'y1'))
+ m.add_property('end', sa.orm.composite(Point, 'x2', 'y2'))
+
+ self._test_roundtrip()
\ No newline at end of file