filter() (or whatever) using <classname>.<attributename>==<whatever>.
for column based properties, all column operators work (i.e. ==, <, >,
like(), in_(), etc.). For relation() and composite column properties,
==<instance>, !=<instance>, and ==<None> are implemented so far.
[ticket:643]
querying divergent criteria. ClauseElements at the front of
filter_by() are removed (use filter()).
+ - added operator support to class-instrumented attributes. you can now
+ filter() (or whatever) using <classname>.<attributename>==<whatever>.
+ for column based properties, all column operators work (i.e. ==, <, >,
+ like(), in_(), etc.). For relation() and composite column properties,
+ ==<instance>, !=<instance>, and ==<None> are implemented so far.
+ [ticket:643]
+
- added composite column properties. using the composite(cls, *columns)
function inside of the "properties" dict, instances of cls will be
created/mapped to a single attribute, comprised of the values
return self.comparator.compare_self()
def operate(self, op, other):
- return self.comparator.operate(op, other)
+ return op(self.comparator, other)
def reverse_operate(self, op, other):
- return self.comparator.reverse_operate(op, other)
+ return op(other, self.comparator)
def hasparent(self, item, optimistic=False):
"""Return the boolean value of a `hasparent` flag attached to the given item.
raise NotImplementedError()
- def compare(self, value):
+ def compare(self, operator, value):
"""Return a compare operation for the columns represented by
this ``MapperProperty`` to the given value, which may be a
- column value or an instance.
+ column value or an instance. 'operator' is an operator from
+ the operators module, or from sql.Comparator.
+
+ By default uses the PropComparator attached to this MapperProperty
+ under the attribute name "comparator".
"""
- raise NotImplementedError()
-
+ return operator(self.comparator, value)
class PropComparator(sql.Comparator):
"""defines comparison operations for MapperProperty objects"""
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
-import sets, random
+import operator
from sqlalchemy.orm.interfaces import *
__all__ = ['ColumnProperty', 'CompositeProperty', 'PropertyLoader', 'BackRef']
def merge(self, session, source, dest, _recursive):
setattr(dest, self.key, getattr(source, self.key, None))
- def compare(self, value, op='=='):
- return self.comparator == value
-
def get_col_value(self, column, value):
return value
def __init__(self, class_, *columns, **kwargs):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
- self.comparator = None
+ self.comparator = CompositeProperty.Comparator(self)
def copy(self):
return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
if a is column:
setattr(obj, b, value)
- def compare(self, value, op='=='):
- # TODO: build into operator framework
- if op == '==':
- return sql.and_([a==b for a, b in zip(self.columns, value.__colset__())])
- elif op == '!=':
- return sql.or_([a!=b for a, b in zip(self.columns, value.__colset__())])
-
def get_col_value(self, column, value):
for a, b in zip(self.columns, value.__colset__()):
if a is column:
return b
-
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return sql.and_(*[a==None for a in self.prop.columns])
+ else:
+ return sql.and_(*[a==b for a, b in zip(self.prop.columns, other.__colset__())])
+
+ def __ne__(self, other):
+ return sql.or_(*[a!=b for a, b in zip(self.prop.columns, other.__colset__())])
+
class PropertyLoader(StrategizedProperty):
"""Describes an object property that holds a single item or list
of items that correspond to a related database table.
self.remote_side = util.to_set(remote_side)
self.enable_typechecks = enable_typechecks
self._parent_join_cache = {}
- self.comparator = None
+ self.comparator = PropertyLoader.Comparator(self)
if cascade is not None:
self.cascade = mapperutil.CascadeOptions(cascade)
self.backref = backref
self.is_backref = is_backref
- def compare(self, value, value_is_parent=False, op='=='):
- if op == '==':
- # optimized operation for ==, uses a lazy clause.
- (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
- bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
-
- class Visitor(sql.ClauseVisitor):
- def visit_bindparam(s, bindparam):
- mapper = value_is_parent and self.parent or self.mapper
- bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
- Visitor().traverse(criterion)
- return criterion
+ class Comparator(PropComparator):
+ def __eq__(self, other):
+ if other is None:
+ return ~sql.exists([1], self.prop.primaryjoin)
+ else:
+ return self.prop._optimized_compare(other)
+
+ def __ne__(self, other):
+ j = self.prop.primaryjoin
+ if self.prop.secondaryjoin:
+ j = j & self.prop.secondaryjoin
+ return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
+
+ def compare(self, op, value, value_is_parent=False):
+ if op == operator.eq:
+ if value is None:
+ return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
+ else:
+ return self._optimized_compare(value, value_is_parent=value_is_parent)
else:
- # TODO: build expressions like these into operator framework
- return sql.and_(*[x==y for (x, y) in zip(self.mapper.primary_key, self.mapper.primary_key_from_instance(value))])
-
+ return op(self.comparator, value)
+
+ def _optimized_compare(self, value, value_is_parent=False):
+ # optimized operation for ==, uses a lazy clause.
+ (criterion, lazybinds, rev) = strategies.LazyLoader._create_lazy_clause(self, reverse_direction=not value_is_parent)
+ bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+
+ class Visitor(sql.ClauseVisitor):
+ def visit_bindparam(s, bindparam):
+ mapper = value_is_parent and self.parent or self.mapper
+ bindparam.value = mapper.get_attr_by_column(value, bind_to_col[bindparam.key])
+ Visitor().traverse(criterion)
+ return criterion
+
private = property(lambda s:s.cascade.delete_orphan)
def create_strategy(self):
from sqlalchemy import sql, util, exceptions, sql_util, logging, schema
from sqlalchemy.orm import mapper, class_mapper, object_mapper
from sqlalchemy.orm.interfaces import OperationContext
+import operator
__all__ = ['Query', 'QueryContext', 'SelectionContext']
mapper = object_mapper(instance)
prop = mapper.get_property(property, resolve_synonyms=True)
target = prop.mapper
- criterion = prop.compare(instance, value_is_parent=True)
+ criterion = prop.compare(operator.eq, instance, value_is_parent=True)
return Query(target, **kwargs).filter(criterion)
query_from_parent = classmethod(query_from_parent)
raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
else:
prop = mapper.get_property(property, resolve_synonyms=True)
- return self.filter(prop.compare(instance, value_is_parent=True))
+ return self.filter(prop.compare(operator.eq, instance, value_is_parent=True))
def add_entity(self, entity):
"""add a mapped entity to the list of result columns to be returned.
for key, value in kwargs.iteritems():
prop = joinpoint.get_property(key, resolve_synonyms=True)
- c = prop.compare(value)
+ c = prop.compare(operator.eq, value)
if alias is not None:
sql_util.ClauseAdapter(alias).traverse(c)
for key, value in params.iteritems():
(keys, prop) = self._locate_prop(key, start=start)
if isinstance(prop, properties.PropertyLoader):
- c = prop.compare(value) & self.join_via(keys[:-1])
+ c = prop.compare(operator.eq, value) & self.join_via(keys[:-1])
else:
- c = prop.compare(value) & self.join_via(keys)
+ c = prop.compare(operator.eq, value) & self.join_via(keys)
if clause is None:
clause = c
else:
"""A class that knows how to traverse and visit
``ClauseElements``.
- Each ``ClauseElement``'s accept_visitor() method will call a
- corresponding visit_XXXX() method here. Traversal of a
+ Calls visit_XXX() methods dynamically generated for each particualr
+ ``ClauseElement`` subclass encountered. Traversal of a
hierarchy of ``ClauseElements`` is achieved via the
``traverse()`` method, which is passed the lead
``ClauseElement``.
(column_collections=False) or to return Schema-level items
(schema_visitor=True).
+ ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
+ operation, which will produce a copy of a given ``ClauseElement``
+ structure while at the same time allowing ``ClauseVisitor`` subclasses
+ to modify the new structure in-place.
+
"""
__traverse_options__ = {}
edges = Table('edges', metadata,
Column('id', Integer, primary_key=True),
- Column('graph_id', Integer, ForeignKey('graphs.id'), nullable=False),
+ Column('graph_id', Integer, nullable=False),
+ Column('graph_version_id', Integer, nullable=False),
Column('x1', Integer),
Column('y1', Integer),
Column('x2', Integer),
- Column('y2', Integer))
-
+ Column('y2', Integer),
+ ForeignKeyConstraint(['graph_id', 'graph_version_id'], ['graphs.id', 'graphs.version_id'])
+ )
+
def test_basic(self):
class Point(object):
def __init__(self, x, y):
assert e1.end == e2.end
self.assert_sql_count(testbase.db, go, 1)
+ # test comparison of CompositeProperties to their object instances
+ g = sess.query(Graph).get([1, 1])
+ assert sess.query(Edge).filter(Edge.start==Point(3, 4)).one() is g.edges[0]
+
+ assert sess.query(Edge).filter(Edge.start!=Point(3, 4)).first() is g.edges[1]
+
+ assert sess.query(Edge).filter(Edge.start==None).all() == []
+
+
def test_pk(self):
"""test using a composite type as a primary key"""
def setup_mappers(self):
mapper(User, users, properties={
- 'addresses':relation(Address),
+ 'addresses':relation(Address, backref='user'),
'orders':relation(Order, backref='user'), # o2m, m2o
})
mapper(Address, addresses)
def test_onefilter(self):
assert [User(id=8), User(id=9)] == create_session().query(User).filter(User.name.endswith('ed')).all()
+ def test_contains(self):
+ """test comparing a collection to an object instance."""
+
+ sess = create_session()
+ address = sess.query(Address).get(3)
+ assert [User(id=8)] == sess.query(User).filter(User.addresses==address).all()
+
+ assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
+
+ assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
+
+ def test_contains_m2m(self):
+ sess = create_session()
+ item = sess.query(Item).get(3)
+ assert [Order(id=1), Order(id=2), Order(id=3)] == sess.query(Order).filter(Order.items==item).all()
+ assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(Order.items!=item).all()
+
+ def test_has(self):
+ """test scalar comparison to an object instance"""
+
+ sess = create_session()
+ user = sess.query(User).get(8)
+ assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user==user).all()
+
+ assert [Address(id=1), Address(id=5)] == sess.query(Address).filter(Address.user!=user).all()
+
+
class CountTest(QueryTest):
def test_basic(self):
assert 4 == create_session().query(User).count()