From: Mike Bayer Date: Mon, 20 Dec 2010 18:47:48 +0000 (-0500) Subject: - crudely, this replaces CompositeProperty's base to be X-Git-Tag: rel_0_7b1~139 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9b007ed28de0fda016511cf242d39afff798070a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - crudely, this replaces CompositeProperty's base to be DescriptorProperty. We have to lose mutability (yikes composites were using mutable too !). Also the getter is not particularly efficient since it recreates the composite every time, probably want to stick it in __dict__. also rewrite the unit tests --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d1b725c8e2..f16376ce60 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -162,80 +162,6 @@ class ColumnProperty(StrategizedProperty): log.class_logger(ColumnProperty) -class CompositeProperty(ColumnProperty): - """subclasses ColumnProperty to provide composite type support.""" - - def __init__(self, class_, *columns, **kwargs): - super(CompositeProperty, self).__init__(*columns, **kwargs) - self._col_position_map = util.column_dict( - (c, i) for i, c - in enumerate(columns)) - self.composite_class = class_ - self.strategy_class = strategies.CompositeColumnLoader - - def copy(self): - return CompositeProperty( - deferred=self.deferred, - group=self.group, - composite_class=self.composite_class, - active_history=self.active_history, - *self.columns) - - def do_init(self): - # skip over ColumnProperty's do_init(), - # which issues assertions that do not apply to CompositeColumnProperty - super(ColumnProperty, self).do_init() - - def _getcommitted(self, state, dict_, column, passive=False): - # TODO: no coverage here - obj = state.get_impl(self.key).\ - get_committed_value(state, dict_, passive=passive) - return self.get_col_value(column, obj) - - def set_col_value(self, state, dict_, value, column): - obj = state.get_impl(self.key).get(state, dict_) - if obj is None: - obj = self.composite_class(*[None for c in self.columns]) - state.get_impl(self.key).set(state, state.dict, obj, None) - - if hasattr(obj, '__set_composite_values__'): - values = list(obj.__composite_values__()) - values[self._col_position_map[column]] = value - obj.__set_composite_values__(*values) - else: - setattr(obj, column.key, value) - - def get_col_value(self, column, value): - if value is None: - return None - for a, b in zip(self.columns, value.__composite_values__()): - if a is column: - return b - - class Comparator(PropComparator): - def __clause_element__(self): - if self.adapter: - # TODO: test coverage for adapted composite comparison - return expression.ClauseList( - *[self.adapter(x) for x in self.prop.columns]) - else: - return expression.ClauseList(*self.prop.columns) - - __hash__ = None - - def __eq__(self, other): - if other is None: - values = [None] * len(self.prop.columns) - else: - values = other.__composite_values__() - return sql.and_( - *[a==b for a, b in zip(self.prop.columns, values)]) - - def __ne__(self, other): - return sql.not_(self.__eq__(other)) - - def __str__(self): - return str(self.parent.class_.__name__) + "." + self.key class DescriptorProperty(MapperProperty): """:class:`MapperProperty` which proxies access to a @@ -243,7 +169,9 @@ class DescriptorProperty(MapperProperty): def instrument_class(self, mapper): from sqlalchemy.ext import hybrid - + + prop = self + # hackety hack hack class _ProxyImpl(object): accepts_scalar_loader = False @@ -251,7 +179,11 @@ class DescriptorProperty(MapperProperty): def __init__(self, key): self.key = key - + + if hasattr(prop, 'get_history'): + def get_history(self, state, dict_, **kw): + return prop.get_history(state, dict_, **kw) + if self.descriptor is None: desc = getattr(mapper.class_, self.key, None) if mapper._is_userland_descriptor(desc): @@ -296,7 +228,7 @@ class DescriptorProperty(MapperProperty): descriptor.expr = get_comparator descriptor.impl = _ProxyImpl(self.key) mapper.class_manager.instrument_attribute(self.key, descriptor) - + def setup(self, context, entity, path, adapter, **kwargs): pass @@ -307,6 +239,105 @@ class DescriptorProperty(MapperProperty): dest_state, dest_dict, load, _recursive): pass +class CompositeProperty(DescriptorProperty): + + def __init__(self, class_, *columns, **kwargs): + self.columns = columns + self.composite_class = class_ + self.active_history = kwargs.get('active_history', False) + self.deferred = kwargs.get('deferred', False) + self.group = kwargs.get('group', None) + + prop = self + def fget(instance): + return prop.composite_class( + *[getattr(instance, prop.parent._columntoproperty[col].key) + for col in prop.columns] + ) + def fset(instance, value): + if value is None: + fdel(instance) + else: + for col, value in zip(prop.columns, value.__composite_values__()): + setattr(instance, prop.parent._columntoproperty[col].key, value) + + def fdel(instance): + for col in prop.columns: + setattr(instance, prop.parent._columntoproperty[col].key, None) + self.descriptor = property(fget, fset, fdel) + + def get_history(self, state, dict_, **kw): + """Provided for userland code that uses attributes.get_history().""" + + added = [] + deleted = [] + + has_history = False + for col in self.columns: + key = self.parent._columntoproperty[col].key + hist = state.manager[key].impl.get_history(state, dict_) + if hist.has_changes(): + has_history = True + + added.extend(hist.non_deleted()) + if hist.deleted: + deleted.extend(hist.deleted) + else: + deleted.append(None) + + if has_history: + return attributes.History( + [self.composite_class(*added)], + (), + [self.composite_class(*deleted)] + ) + else: + return attributes.History( + (),[self.composite_class(*added)], () + ) + + def do_init(self): + for col in self.columns: + prop = self.parent._columntoproperty[col] + prop.active_history = self.active_history + if self.deferred: + prop.deferred = self.deferred + prop.strategy_class = strategies.DeferredColumnLoader + prop.group = self.group + # strategies ... + + def _comparator_factory(self, mapper): + return CompositeProperty.Comparator(self) + + class Comparator(PropComparator): + def __init__(self, prop, adapter=None): + self.prop = prop + self.adapter = adapter + + def __clause_element__(self): + if self.adapter: + # TODO: test coverage for adapted composite comparison + return expression.ClauseList( + *[self.adapter(x) for x in self.prop.columns]) + else: + return expression.ClauseList(*self.prop.columns) + + __hash__ = None + + def __eq__(self, other): + if other is None: + values = [None] * len(self.prop.columns) + else: + values = other.__composite_values__() + return sql.and_( + *[a==b for a, b in zip(self.prop.columns, values)]) + + def __ne__(self, other): + return sql.not_(self.__eq__(other)) + + def __str__(self): + return str(self.parent.class_.__name__) + "." + self.key + class ConcreteInheritedProperty(DescriptorProperty): """A 'do nothing' :class:`MapperProperty` that disables an attribute on a concrete subclass that is only present diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index c4619d3a72..21f22ef509 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -138,59 +138,6 @@ class ColumnLoader(LoaderStrategy): log.class_logger(ColumnLoader) -class CompositeColumnLoader(ColumnLoader): - """Strategize the loading of a composite column-based MapperProperty.""" - - def init_class_attribute(self, mapper): - self.is_class_level = True - self.logger.info("%s register managed composite attribute", self) - - def copy(obj): - if obj is None: - return None - return self.parent_property.\ - composite_class(*obj.__composite_values__()) - - def compare(a, b): - if a is None or b is None: - return a is b - - for col, aprop, bprop in zip(self.columns, - a.__composite_values__(), - b.__composite_values__()): - if not col.type.compare_values(aprop, bprop): - return False - else: - return True - - _register_attribute(self, mapper, useobject=False, - compare_function=compare, - copy_function=copy, - mutable_scalars=True, - active_history=self.parent_property.active_history, - ) - - def create_row_processor(self, selectcontext, path, mapper, - row, adapter): - key = self.key - columns = self.columns - composite_class = self.parent_property.composite_class - if adapter: - columns = [adapter.columns[c] for c in columns] - - for c in columns: - if c not in row: - def new_execute(state, dict_, row): - state.expire_attribute_pre_commit(dict_, key) - break - else: - def new_execute(state, dict_, row): - dict_[key] = composite_class(*[row[c] for c in columns]) - - return new_execute, None, None - -log.class_logger(CompositeColumnLoader) - class DeferredColumnLoader(LoaderStrategy): """Strategize the loading of a deferred column-based MapperProperty.""" diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index c94ef9b3fd..621e5f47ce 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import mapper, relationship, backref, \ validates, aliased, Mapper from sqlalchemy.orm import defer, deferred, synonym, attributes, \ column_property, composite, relationship, dynamic_loader, \ - comparable_property, AttributeExtension + comparable_property, AttributeExtension, Session from sqlalchemy.orm.instrumentation import ClassManager from test.lib.testing import eq_, AssertsCompiledSQL from test.orm import _base, _fixtures @@ -2213,39 +2213,34 @@ class CompositeTypesTest(_base.MappedTest): 'end': sa.orm.composite(Point, edges.c.x2, edges.c.y2) }) - sess = create_session() + sess = Session() g = Graph() g.id = 1 g.version_id=1 g.edges.append(Edge(Point(3, 4), Point(5, 6))) g.edges.append(Edge(Point(14, 5), Point(2, 7))) sess.add(g) - sess.flush() + sess.commit() - sess.expunge_all() g2 = sess.query(Graph).get([g.id, g.version_id]) for e1, e2 in zip(g.edges, g2.edges): eq_(e1.start, e2.start) eq_(e1.end, e2.end) g2.edges[1].end = Point(18, 4) - sess.flush() - sess.expunge_all() + sess.commit() + e = sess.query(Edge).get(g2.edges[1].id) eq_(e.end, Point(18, 4)) - - e.end.x = 19 - e.end.y = 5 - sess.flush() - sess.expunge_all() - eq_(sess.query(Edge).get(g2.edges[1].id).end, Point(19, 5)) - - g.edges[1].end = Point(19, 5) - + + e.end = Point(19, 5) + sess.commit() + g.id, g.version_id, g.edges sess.expunge_all() + def go(): - g2 = (sess.query(Graph). - options(sa.orm.joinedload('edges'))).get([g.id, g.version_id]) + g2 = sess.query(Graph).\ + options(sa.orm.joinedload('edges')).get([g.id, g.version_id]) for e1, e2 in zip(g.edges, g2.edges): eq_(e1.start, e2.start) eq_(e1.end, e2.end) @@ -2261,9 +2256,9 @@ class CompositeTypesTest(_base.MappedTest): # query by columns eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) - + e = g.edges[1] - e.end.x = e.end.y = None + del e.end sess.flush() eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)])