From: Mike Bayer Date: Tue, 24 Jul 2007 22:00:19 +0000 (+0000) Subject: - added has(), like any() but for scalars X-Git-Tag: rel_0_4_6~29 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2da4d1179265ed83e69b927ec75dbfeb3ad6d802;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added has(), like any() but for scalars - added **kwargs to has() and any(), criterion is optional; generate equality criterion against the related table (since we know the related property when has() and any() are used), i.e. filter(Address.user.has(name='jack')) equivalent to filter(Address.user.has(User.name=='jack')) - added "from_joinpoint=False" arg to join()/outerjoin(). yes, I know join() is getting a little crazy, but this flag is needed when you want to keep building along a line of aliased joins, adding query criterion for each alias in the chain. self-referential unit test added. - fixed basic_tree example a little bit --- diff --git a/examples/adjacencytree/basic_tree.py b/examples/adjacencytree/basic_tree.py index 9676fae89c..9f937315b8 100644 --- a/examples/adjacencytree/basic_tree.py +++ b/examples/adjacencytree/basic_tree.py @@ -1,7 +1,9 @@ """a basic Adjacency List model tree.""" from sqlalchemy import * +from sqlalchemy.orm import * from sqlalchemy.util import OrderedDict +from sqlalchemy.orm.collections import attribute_mapped_collection metadata = MetaData('sqlite:///', echo=True) @@ -11,17 +13,10 @@ trees = Table('treenodes', metadata, Column('node_name', String(50), nullable=False), ) -class NodeList(OrderedDict): - """subclasses OrderedDict to allow usage as a list-based property.""" - def append(self, node): - self[node.name] = node - def __iter__(self): - return iter(self.values()) class TreeNode(object): """a rich Tree class which includes path-based operations""" def __init__(self, name): - self.children = NodeList() self.name = name self.parent = None self.id = None @@ -30,7 +25,7 @@ class TreeNode(object): if isinstance(node, str): node = TreeNode(node) node.parent = self - self.children.append(node) + self.children[node.name] = node def __repr__(self): return self._getstring(0, False) def __str__(self): @@ -47,7 +42,7 @@ mapper(TreeNode, trees, properties=dict( id=trees.c.node_id, name=trees.c.node_name, parent_id=trees.c.parent_node_id, - children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=NodeList), + children=relation(TreeNode, cascade="all", backref=backref("parent", remote_side=[trees.c.node_id]), collection_class=attribute_mapped_collection('name')), )) print "\n\n\n----------------------------" diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index e583504d1f..8028f0fc81 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -85,11 +85,11 @@ class InstrumentedAttribute(interfaces.PropComparator): def clause_element(self): return self.comparator.clause_element() - def operate(self, op, other): - return op(self.comparator, other) + def operate(self, op, other, **kwargs): + return op(self.comparator, other, **kwargs) - def reverse_operate(self, op, other): - return op(other, self.comparator) + def reverse_operate(self, op, other, **kwargs): + return op(other, self.comparator, **kwargs) def hasparent(self, item, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 42ae5c8ddf..30a6088908 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -58,6 +58,8 @@ class MapperExtension(object): The return value of this method is used as the result of ``query.get_by()`` if the value is anything other than EXT_PASS. + + DEPRECATED. """ return EXT_PASS @@ -68,6 +70,8 @@ class MapperExtension(object): The return value of this method is used as the result of ``query.select_by()`` if the value is anything other than EXT_PASS. + + DEPRECATED. """ return EXT_PASS @@ -78,6 +82,8 @@ class MapperExtension(object): The return value of this method is used as the result of ``query.select()`` if the value is anything other than EXT_PASS. + + DEPRECATED. """ return EXT_PASS @@ -344,10 +350,14 @@ class PropComparator(sql.ColumnOperators): return a.contains(b) contains_op = staticmethod(contains_op) - def any_op(a, b): - return a.any(b) + def any_op(a, b, **kwargs): + return a.any(b, **kwargs) any_op = staticmethod(any_op) + def has_op(a, b, **kwargs): + return a.has(b, **kwargs) + has_op = staticmethod(has_op) + def __init__(self, prop): self.prop = prop @@ -355,9 +365,32 @@ class PropComparator(sql.ColumnOperators): """return true if this collection contains other""" return self.operate(PropComparator.contains_op, other) - def any(self, criterion): - """return true if this collection contains any member that meets the given criterion""" - return self.operate(PropComparator.any_op, criterion) + def any(self, criterion=None, **kwargs): + """return true if this collection contains any member that meets the given criterion. + + criterion + an optional ClauseElement formulated against the member class' table or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which will be compared + via equality to the corresponding values. + """ + + return self.operate(PropComparator.any_op, criterion, **kwargs) + + def has(self, criterion=None, **kwargs): + """return true if this element references a member which meets the given criterion. + + + criterion + an optional ClauseElement formulated against the member class' table or attributes. + + \**kwargs + key/value pairs corresponding to member class attribute names which will be compared + via equality to the corresponding values. + """ + + return self.operate(PropComparator.has_op, criterion, **kwargs) class StrategizedProperty(MapperProperty): """A MapperProperty which uses selectable strategies to affect diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 3985323604..e0500f2bd1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -168,7 +168,7 @@ class PropertyLoader(StrategizedProperty): return ~sql.exists([1], self.prop.primaryjoin) elif self.prop.uselist: if not hasattr(other, '__iter__'): - raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.") else: j = self.prop.primaryjoin if self.prop.secondaryjoin: @@ -182,17 +182,37 @@ class PropertyLoader(StrategizedProperty): else: return self.prop._optimized_compare(other) - def any(self, criterion): + def any(self, criterion=None, **kwargs): if not self.prop.uselist: - raise exceptions.InvalidRequestError("'any' not implemented for scalar attributes") + raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") j = self.prop.primaryjoin if self.prop.secondaryjoin: j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit + return sql.exists([1], j & criterion) + + def has(self, criterion=None, **kwargs): + if self.prop.uselist: + raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") + j = self.prop.primaryjoin + if self.prop.secondaryjoin: + j = j & self.prop.secondaryjoin + for k in kwargs: + crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + if criterion is None: + criterion = crit + else: + criterion = criterion & crit return sql.exists([1], j & criterion) def contains(self, other): if not self.prop.uselist: - raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes") + raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") clause = self.prop._optimized_compare(other) j = self.prop.primaryjoin diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index da5bc753fb..7b3f27a614 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -309,7 +309,7 @@ class Query(object): def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): if start is None: start = self._joinpoint - + clause = self._from_obj[-1] currenttables = [clause] @@ -321,7 +321,7 @@ class Query(object): mapper = start - alias = None + alias = self._aliases for key in util.to_list(keys): prop = mapper.get_property(key, resolve_synonyms=True) if prop._is_self_referential() and not create_aliases: @@ -444,7 +444,7 @@ class Query(object): q._group_by = q._group_by + util.to_list(criterion) return q - def join(self, prop, id=None, aliased=False): + def join(self, prop, id=None, aliased=False, from_joinpoint=False): """create a join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. @@ -452,9 +452,9 @@ class Query(object): property names. """ - return self._join(prop, id=id, outerjoin=False, aliased=aliased) + return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint) - def outerjoin(self, prop, id=None, aliased=False): + def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False): """create a left outer join of this ``Query`` object's criterion to a relationship and return the newly resulting ``Query``. @@ -462,10 +462,10 @@ class Query(object): property names. """ - return self._join(prop, id=id, outerjoin=True, aliased=aliased) + return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint) - def _join(self, prop, id, outerjoin, aliased): - (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=self.mapper, create_aliases=aliased) + def _join(self, prop, id, outerjoin, aliased, from_joinpoint): + (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) q = self._clone() q._from_obj = [clause] q._joinpoint = mapper diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index cf6c2b1a26..326fb93e56 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -1137,10 +1137,10 @@ class Operators(object): def clause_element(self): raise NotImplementedError() - def operate(self, op, *other): + def operate(self, op, *other, **kwargs): raise NotImplementedError() - def reverse_operate(self, op, *other): + def reverse_operate(self, op, *other, **kwargs): raise NotImplementedError() class ColumnOperators(Operators): diff --git a/test/orm/query.py b/test/orm/query.py index d4a5120eaf..ea5cc64def 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -6,40 +6,6 @@ from sqlalchemy.orm import * from testlib import * from fixtures import * -class Base(object): - def __init__(self, **kwargs): - for k in kwargs: - setattr(self, k, kwargs[k]) - - def __ne__(self, other): - return not self.__eq__(other) - - def __eq__(self, other): - """'passively' compare this object to another. - - only look at attributes that are present on the source object. - - """ - # use __dict__ to avoid instrumented properties - for attr in self.__dict__.keys(): - if attr[0] == '_': - continue - value = getattr(self, attr) - if hasattr(value, '__iter__') and not isinstance(value, basestring): - if len(value) == 0: - continue - for (us, them) in zip(value, getattr(other, attr)): - if us != them: - return False - else: - continue - else: - if value is not None: - if value != getattr(other, attr): - return False - else: - return True - class QueryTest(ORMTest): keep_mappers = True keep_data = True @@ -294,9 +260,21 @@ class FilterTest(QueryTest): def test_any(self): sess = create_session() - address = sess.query(Address).get(3) + assert [User(id=8), User(id=9)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'))).all() + + assert [User(id=8)] == sess.query(User).filter(User.addresses.any(Address.email_address.like('%ed%'), id=4)).all() + + assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all() + + def test_has(self): + sess = create_session() + assert [Address(id=5)] == sess.query(Address).filter(Address.user.has(name='fred')).all() + assert [Address(id=2), Address(id=3), Address(id=4), Address(id=5)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'))).all() + + assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all() + def test_contains_m2m(self): sess = create_session() item = sess.query(Item).get(3) @@ -304,7 +282,7 @@ class FilterTest(QueryTest): assert [Order(id=4), Order(id=5)] == sess.query(Order).filter(~Order.items.contains(item)).all() - def test_has(self): + def test_comparison(self): """test scalar comparison to an object instance""" sess = create_session() @@ -729,7 +707,9 @@ class SelfReferentialJoinTest(ORMTest): self.children.append(node) mapper(Node, nodes, properties={ - 'children':relation(Node, lazy=True, join_depth=3) + 'children':relation(Node, lazy=True, join_depth=3, + backref=backref('parent', remote_side=[nodes.c.id]) + ) }) sess = create_session() n1 = Node(data='n1') @@ -751,6 +731,10 @@ class SelfReferentialJoinTest(ORMTest): node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first() assert node.data=='n1' + + node = sess.query(Node).filter_by(data='n122').join('parent', aliased=True).filter_by(data='n12').\ + join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first() + assert node.data == 'n122' class ExternalColumnsTest(QueryTest): keep_mappers = False