From bd3816a1fd8ae4e0dfbbca3f148dd1e65f48f5c7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 17 Jul 2007 04:25:09 +0000 Subject: [PATCH] - added operator support to class-instrumented attributes. you can now filter() (or whatever) using .==. for column based properties, all column operators work (i.e. ==, <, >, like(), in_(), etc.). For relation() and composite column properties, ==, !=, and == are implemented so far. [ticket:643] --- CHANGES | 7 +++ lib/sqlalchemy/orm/attributes.py | 4 +- lib/sqlalchemy/orm/interfaces.py | 11 +++-- lib/sqlalchemy/orm/properties.py | 75 ++++++++++++++++++++------------ lib/sqlalchemy/orm/query.py | 11 ++--- lib/sqlalchemy/sql.py | 9 +++- test/orm/mapper.py | 18 ++++++-- test/orm/query.py | 29 +++++++++++- 8 files changed, 118 insertions(+), 46 deletions(-) diff --git a/CHANGES b/CHANGES index d777d45dfc..e2d6010bf3 100644 --- a/CHANGES +++ b/CHANGES @@ -38,6 +38,13 @@ querying divergent criteria. ClauseElements at the front of filter_by() are removed (use filter()). + - added operator support to class-instrumented attributes. you can now + filter() (or whatever) using .==. + for column based properties, all column operators work (i.e. ==, <, >, + like(), in_(), etc.). For relation() and composite column properties, + ==, !=, and == are implemented so far. + [ticket:643] + - added composite column properties. using the composite(cls, *columns) function inside of the "properties" dict, instances of cls will be created/mapped to a single attribute, comprised of the values diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 84d464e204..ad9675f029 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -85,10 +85,10 @@ class InstrumentedAttribute(sql.Comparator): return self.comparator.compare_self() def operate(self, op, other): - return self.comparator.operate(op, other) + return op(self.comparator, other) def reverse_operate(self, op, other): - return self.comparator.reverse_operate(op, other) + return op(other, self.comparator) def hasparent(self, item, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index b0a2399f3b..c9c2a45b51 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -325,14 +325,17 @@ class MapperProperty(object): raise NotImplementedError() - def compare(self, value): + def compare(self, operator, value): """Return a compare operation for the columns represented by this ``MapperProperty`` to the given value, which may be a - column value or an instance. + column value or an instance. 'operator' is an operator from + the operators module, or from sql.Comparator. + + By default uses the PropComparator attached to this MapperProperty + under the attribute name "comparator". """ - raise NotImplementedError() - + return operator(self.comparator, value) class PropComparator(sql.Comparator): """defines comparison operations for MapperProperty objects""" diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a32354fcad..8a57b4a83e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -15,7 +15,7 @@ from sqlalchemy import sql, schema, util, exceptions, sql_util, logging from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil -import sets, random +import operator from sqlalchemy.orm.interfaces import * __all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef'] @@ -56,9 +56,6 @@ class ColumnProperty(StrategizedProperty): def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) - def compare(self, value, op='=='): - return self.comparator == value - def get_col_value(self, column, value): return value @@ -84,7 +81,7 @@ class CompositeProperty(ColumnProperty): def __init__(self, class_, *columns, **kwargs): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ - self.comparator = None + self.comparator = CompositeProperty.Comparator(self) def copy(self): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) @@ -101,19 +98,21 @@ class CompositeProperty(ColumnProperty): if a is column: setattr(obj, b, value) - def compare(self, value, op='=='): - # TODO: build into operator framework - if op == '==': - return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())]) - elif op == '!=': - return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())]) - def get_col_value(self, column, value): for a, b in zip(self.columns, value.__colset__()): if a is column: return b - + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return sql.and_(*[a==None for a in self.prop.columns]) + else: + return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())]) + + def __ne__(self, other): + return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())]) + class PropertyLoader(StrategizedProperty): """Describes an object property that holds a single item or list of items that correspond to a related database table. @@ -137,7 +136,7 @@ class PropertyLoader(StrategizedProperty): self.remote_side = util.to_set(remote_side) self.enable_typechecks = enable_typechecks self._parent_join_cache = {} - self.comparator = None + self.comparator = PropertyLoader.Comparator(self) if cascade is not None: self.cascade = mapperutil.CascadeOptions(cascade) @@ -162,22 +161,40 @@ class PropertyLoader(StrategizedProperty): self.backref = backref self.is_backref = is_backref - def compare(self, value, value_is_parent=False, op='=='): - if op == '==': - # optimized operation for ==, uses a lazy clause. - (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) - bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - - class Visitor(sql.ClauseVisitor): - def visit_bindparam(s, bindparam): - mapper = value_is_parent and self.parent or self.mapper - bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) - Visitor().traverse(criterion) - return criterion + class Comparator(PropComparator): + def __eq__(self, other): + if other is None: + return ~sql.exists([1], self.prop.primaryjoin) + else: + return self.prop._optimized_compare(other) + + def __ne__(self, other): + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])) + + def compare(self, op, value, value_is_parent=False): + if op == operator.eq: + if value is None: + return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin) + else: + return self._optimized_compare(value, value_is_parent=value_is_parent) else: - # TODO: build expressions like these into operator framework - return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))]) - + return op(self.comparator, value) + + def _optimized_compare(self, value, value_is_parent=False): + # optimized operation for ==, uses a lazy clause. + (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent) + bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) + + class Visitor(sql.ClauseVisitor): + def visit_bindparam(s, bindparam): + mapper = value_is_parent and self.parent or self.mapper + bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key]) + Visitor().traverse(criterion) + return criterion + private = property(lambda s:s.cascade.delete_orphan) def create_strategy(self): diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 3937149ee5..0537ee258e 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -7,6 +7,7 @@ from sqlalchemy import sql, util, exceptions, sql_util, logging, schema from sqlalchemy.orm import mapper, class_mapper, object_mapper from sqlalchemy.orm.interfaces import OperationContext +import operator __all__ = ['Query', 'QueryContext', 'SelectionContext'] @@ -120,7 +121,7 @@ class Query(object): mapper = object_mapper(instance) prop = mapper.get_property(property, resolve_synonyms=True) target = prop.mapper - criterion = prop.compare(instance, value_is_parent=True) + criterion = prop.compare(operator.eq, instance, value_is_parent=True) return Query(target, **kwargs).filter(criterion) query_from_parent = classmethod(query_from_parent) @@ -149,7 +150,7 @@ class Query(object): raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__)) else: prop = mapper.get_property(property, resolve_synonyms=True) - return self.filter(prop.compare(instance, value_is_parent=True)) + return self.filter(prop.compare(operator.eq, instance, value_is_parent=True)) def add_entity(self, entity): """add a mapped entity to the list of result columns to be returned. @@ -265,7 +266,7 @@ class Query(object): for key, value in kwargs.iteritems(): prop = joinpoint.get_property(key, resolve_synonyms=True) - c = prop.compare(value) + c = prop.compare(operator.eq, value) if alias is not None: sql_util.ClauseAdapter(alias).traverse(c) @@ -1011,9 +1012,9 @@ class Query(object): for key, value in params.iteritems(): (keys, prop) = self._locate_prop(key, start=start) if isinstance(prop, properties.PropertyLoader): - c = prop.compare(value) & self.join_via(keys[:-1]) + c = prop.compare(operator.eq, value) & self.join_via(keys[:-1]) else: - c = prop.compare(value) & self.join_via(keys) + c = prop.compare(operator.eq, value) & self.join_via(keys) if clause is None: clause = c else: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index e177f4194a..b6a843685c 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -870,8 +870,8 @@ class ClauseVisitor(object): """A class that knows how to traverse and visit ``ClauseElements``. - Each ``ClauseElement``'s accept_visitor() method will call a - corresponding visit_XXXX() method here. Traversal of a + Calls visit_XXX() methods dynamically generated for each particualr + ``ClauseElement`` subclass encountered. Traversal of a hierarchy of ``ClauseElements`` is achieved via the ``traverse()`` method, which is passed the lead ``ClauseElement``. @@ -885,6 +885,11 @@ class ClauseVisitor(object): (column_collections=False) or to return Schema-level items (schema_visitor=True). + ``ClauseVisitor`` also supports a simultaneous copy-and-traverse + operation, which will produce a copy of a given ``ClauseElement`` + structure while at the same time allowing ``ClauseVisitor`` subclasses + to modify the new structure in-place. + """ __traverse_options__ = {} diff --git a/test/orm/mapper.py b/test/orm/mapper.py index eb0d110a16..e6c03161c2 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -844,12 +844,15 @@ class CompositeTypesTest(ORMTest): edges = Table('edges', metadata, Column('id', Integer, primary_key=True), - Column('graph_id', Integer, ForeignKey('graphs.id'), nullable=False), + Column('graph_id', Integer, nullable=False), + Column('graph_version_id', Integer, nullable=False), Column('x1', Integer), Column('y1', Integer), Column('x2', Integer), - Column('y2', Integer)) - + Column('y2', Integer), + ForeignKeyConstraint(['graph_id', 'graph_version_id'], ['graphs.id', 'graphs.version_id']) + ) + def test_basic(self): class Point(object): def __init__(self, x, y): @@ -914,6 +917,15 @@ class CompositeTypesTest(ORMTest): assert e1.end == e2.end self.assert_sql_count(testbase.db, go, 1) + # test comparison of CompositeProperties to their object instances + g = sess.query(Graph).get([1, 1]) + assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0] + + assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1] + + assert sess.query(Edge).filter(Edge.start==None).all() == [] + + def test_pk(self): """test using a composite type as a primary key""" diff --git a/test/orm/query.py b/test/orm/query.py index c5c48f4bb4..75885fb8ce 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -60,7 +60,7 @@ class QueryTest(testbase.ORMTest): def setup_mappers(self): mapper(User, users, properties={ - 'addresses':relation(Address), + 'addresses':relation(Address, backref='user'), 'orders':relation(Order, backref='user'), # o2m, m2o }) mapper(Address, addresses) @@ -196,7 +196,34 @@ class FilterTest(QueryTest): def test_onefilter(self): assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all() + def test_contains(self): + """test comparing a collection to an object instance.""" + + sess = create_session() + address = sess.query(Address).get(3) + assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all() + + assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all() + + assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all() + + def test_contains_m2m(self): + sess = create_session() + item = sess.query(Item).get(3) + assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all() + assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all() + + def test_has(self): + """test scalar comparison to an object instance""" + + sess = create_session() + user = sess.query(User).get(8) + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all() + + assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all() + + class CountTest(QueryTest): def test_basic(self): assert 4 == create_session().query(User).count() -- 2.47.3