From: Mike Bayer Date: Sat, 15 Jan 2011 18:25:03 +0000 (-0500) Subject: - composite now relates to its parent class in terms of MapperProperty, X-Git-Tag: rel_0_7b1~76 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c3fb278063022575bbf2cb5e5e48025dd006d9b5;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - composite now relates to its parent class in terms of MapperProperty, not Column. This allows it to compose any mapped attributes, including relationship(). [ticket:2024] --- diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index e6166aa9ec..79c57ac0e4 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -12,8 +12,9 @@ as actively in the load/persist ORM loop. from sqlalchemy.orm.interfaces import \ MapperProperty, PropComparator, StrategizedProperty +from sqlalchemy.orm.mapper import _none_set from sqlalchemy.orm import attributes -from sqlalchemy import util, sql, exc as sa_exc, event +from sqlalchemy import util, sql, exc as sa_exc, event, schema from sqlalchemy.sql import expression properties = util.importlater('sqlalchemy.orm', 'properties') @@ -72,8 +73,8 @@ class DescriptorProperty(MapperProperty): class CompositeProperty(DescriptorProperty): - def __init__(self, class_, *columns, **kwargs): - self.columns = columns + def __init__(self, class_, *attrs, **kwargs): + self.attrs = attrs self.composite_class = class_ self.active_history = kwargs.get('active_history', False) self.deferred = kwargs.get('deferred', False) @@ -90,6 +91,7 @@ class CompositeProperty(DescriptorProperty): has been associated with its parent mapper. """ + self._init_props() self._setup_arguments_on_columns() def _create_descriptor(self): @@ -101,12 +103,17 @@ class CompositeProperty(DescriptorProperty): def fget(instance): dict_ = attributes.instance_dict(instance) - # key not present, assume the columns aren't - # loaded. The load events will establish - # the item. if self.key not in dict_: - for key in self._attribute_keys: - getattr(instance, key) + # key not present. Iterate through related + # attributes, retrieve their values. This + # ensures they all load. + values = [getattr(instance, key) for key in self._attribute_keys] + + # usually, the load() event will have loaded our key + # at this point, unless we only loaded relationship() + # attributes above. Populate here if that's the case. + if self.key not in dict_ and not _none_set.issuperset(values): + dict_[self.key] = self.composite_class(*values) return dict_.get(self.key, None) @@ -138,13 +145,30 @@ class CompositeProperty(DescriptorProperty): self.descriptor = property(fget, fset, fdel) + @util.memoized_property + def _comparable_elements(self): + return [ + getattr(self.parent.class_, prop.key) + for prop in self.props + ] + + def _init_props(self): + self.props = props = [] + for attr in self.attrs: + if isinstance(attr, basestring): + prop = self.parent.get_property(attr) + elif isinstance(attr, schema.Column): + prop = self.parent._columntoproperty[attr] + elif isinstance(attr, attributes.InstrumentedAttribute): + prop = attr.property + props.append(prop) + def _setup_arguments_on_columns(self): """Propagate configuration arguments made on this composite to the target columns, for those that apply. """ - for col in self.columns: - prop = self.parent._columntoproperty[col] + for prop in self.props: prop.active_history = self.active_history if self.deferred: prop.deferred = self.deferred @@ -195,8 +219,7 @@ class CompositeProperty(DescriptorProperty): @util.memoized_property def _attribute_keys(self): return [ - self.parent._columntoproperty[col].key - for col in self.columns + prop.key for prop in self.props ] def get_history(self, state, dict_, **kw): @@ -206,8 +229,8 @@ class CompositeProperty(DescriptorProperty): deleted = [] has_history = False - for col in self.columns: - key = self.parent._columntoproperty[col].key + for prop in self.props: + key = prop.key hist = state.manager[key].impl.get_history(state, dict_) if hist.has_changes(): has_history = True @@ -241,19 +264,19 @@ class CompositeProperty(DescriptorProperty): if self.adapter: # TODO: test coverage for adapted composite comparison return expression.ClauseList( - *[self.adapter(x) for x in self.prop.columns]) + *[self.adapter(x) for x in self.prop._comparable_elements]) else: - return expression.ClauseList(*self.prop.columns) + return expression.ClauseList(*self.prop._comparable_elements) __hash__ = None def __eq__(self, other): if other is None: - values = [None] * len(self.prop.columns) + values = [None] * len(self.prop._comparable_elements) else: values = other.__composite_values__() return sql.and_( - *[a==b for a, b in zip(self.prop.columns, values)]) + *[a==b for a, b in zip(self.prop._comparable_elements, values)]) def __ne__(self, other): return sql.not_(self.__eq__(other)) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 46d6bb44d9..5ebf46e0b9 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -426,3 +426,137 @@ class MappedSelectTest(_base.MappedTest): testing.db.execute(values.select()).fetchall(), [(1, 1, u'Red', u'5'), (2, 1, u'Blue', u'1')] ) + +class ManyToOneTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('a', + metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('b1', String(20)), + Column('b2_id', Integer, ForeignKey('b.id')) + ) + + Table('b', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(20)) + ) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class A(_base.ComparableEntity): + pass + class B(_base.ComparableEntity): + pass + + class C(_base.BasicEntity): + def __init__(self, b1, b2): + self.b1, self.b2 = b1, b2 + + def __composite_values__(self): + return self.b1, self.b2 + + def __eq__(self, other): + return isinstance(other, C) and \ + other.b1 == self.b1 and \ + other.b2 == self.b2 + + + mapper(A, a, properties={ + 'b2':relationship(B), + 'c':composite(C, 'b1', 'b2') + }) + mapper(B, b) + + @testing.resolve_artifact_names + def test_persist(self): + sess = Session() + sess.add(A(c=C('b1', B(data='b2')))) + sess.commit() + + a1 = sess.query(A).one() + eq_(a1.c, C('b1', B(data='b2'))) + + @testing.resolve_artifact_names + def test_query(self): + sess = Session() + b1, b2 = B(data='b1'), B(data='b2') + a1 = A(c=C('a1b1', b1)) + a2 = A(c=C('a2b1', b2)) + sess.add_all([a1, a2]) + sess.commit() + + eq_( + sess.query(A).filter(A.c==C('a2b1', b2)).one(), + a2 + ) + +class ConfigurationTest(_base.MappedTest): + @classmethod + def define_tables(cls, metadata): + Table('edge', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), + Column('x1', Integer), + Column('y1', Integer), + Column('x2', Integer), + Column('y2', Integer), + ) + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + class Point(_base.BasicEntity): + def __init__(self, x, y): + self.x = x + self.y = y + def __composite_values__(self): + return [self.x, self.y] + def __eq__(self, other): + return isinstance(other, Point) and \ + other.x == self.x and \ + other.y == self.y + def __ne__(self, other): + return not isinstance(other, Point) or \ + not self.__eq__(other) + + class Edge(_base.ComparableEntity): + pass + + @testing.resolve_artifact_names + def _test_roundtrip(self): + e1 = Edge(start=Point(3, 4), end=Point(5, 6)) + sess = Session() + sess.add(e1) + sess.commit() + + eq_( + sess.query(Edge).one(), + Edge(start=Point(3, 4), end=Point(5, 6)) + ) + + @testing.resolve_artifact_names + def test_columns(self): + mapper(Edge, edge, properties={ + 'start':sa.orm.composite(Point, edge.c.x1, edge.c.y1), + 'end': sa.orm.composite(Point, edge.c.x2, edge.c.y2) + }) + + self._test_roundtrip() + + @testing.resolve_artifact_names + def test_attributes(self): + m = mapper(Edge, edge) + m.add_property('start', sa.orm.composite(Point, Edge.x1, Edge.y1)) + m.add_property('end', sa.orm.composite(Point, Edge.x2, Edge.y2)) + + self._test_roundtrip() + + @testing.resolve_artifact_names + def test_strings(self): + m = mapper(Edge, edge) + m.add_property('start', sa.orm.composite(Point, 'x1', 'y1')) + m.add_property('end', sa.orm.composite(Point, 'x2', 'y2')) + + self._test_roundtrip() \ No newline at end of file