From: Mike Bayer Date: Fri, 22 Jan 2010 20:24:27 +0000 (+0000) Subject: - association_proxy now has basic comparator methods .any(), X-Git-Tag: rel_0_6beta1~40 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=2d15d9b0d0b20190d5aa348edaf4b398263464f7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - association_proxy now has basic comparator methods .any(), .has(), .contains(), ==, !=, thanks to Scott Torborg. [ticket:1372] --- diff --git a/CHANGES b/CHANGES index 895344dc45..28feff9678 100644 --- a/CHANGES +++ b/CHANGES @@ -871,6 +871,10 @@ CHANGES serializability and subclassing of the built in collections. [ticket:1259] + - association_proxy now has basic comparator methods .any(), + .has(), .contains(), ==, !=, thanks to Scott Torborg. + [ticket:1372] + - examples - The "query_cache" examples have been removed, and are replaced with a fully comprehensive approach that combines the usage of diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 4353558f82..b63bd9b006 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -13,6 +13,7 @@ from sqlalchemy import exceptions from sqlalchemy import orm from sqlalchemy import util from sqlalchemy.orm import collections +from sqlalchemy.sql import not_ def association_proxy(target_collection, attr, **kw): @@ -266,6 +267,26 @@ class AssociationProxy(object): 'no proxy_bulk_set supplied for custom ' 'collection_class implementation') + @property + def _comparator(self): + return self._get_property().comparator + + def any(self, criterion=None, **kwargs): + return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + + def has(self, criterion=None, **kwargs): + return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + + def contains(self, obj): + return self._comparator.any(**{self.value_attr: obj}) + + def __eq__(self, obj): + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj): + return not_(self.__eq__(obj)) + + class _lazy_collection(object): def __init__(self, obj, target): self.ref = weakref.ref(obj) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index e7a6de5e19..81183d14a4 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -9,6 +9,8 @@ from sqlalchemy.ext.associationproxy import * from sqlalchemy.ext.associationproxy import _AssociationList from sqlalchemy.test import * from sqlalchemy.test.util import gc_collect +from sqlalchemy.sql import not_ +from test.orm import _base class DictCollection(dict): @@ -997,4 +999,170 @@ class PickleKeyFunc(object): self.name = name def __call__(self, obj): - return getattr(obj, self.name) \ No newline at end of file + return getattr(obj, self.name) + +class ComparatorTest(_base.MappedTest): + run_inserts = 'once' + run_deletes = None + run_setup_mappers = 'once' + + @classmethod + def define_tables(cls, metadata): + + Table( + 'userkeywords', metadata, + Column('keyword_id', Integer, ForeignKey('keywords.id'), + primary_key=True), + Column('user_id', Integer, ForeignKey('users.id'))) + + Table( + 'users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(64))) + + Table( + 'keywords', metadata, + Column('id', Integer, primary_key=True), + Column('keyword', String(64))) + + @classmethod + def setup_classes(cls): + class User(_base.ComparableEntity): + def __init__(self, name): + self.name = name + keywords = association_proxy('user_keywords', 'keyword', + creator=lambda k: UserKeyword(keyword=k)) + + class Keyword(_base.ComparableEntity): + def __init__(self, keyword): + self.keyword = keyword + user = association_proxy('user_keyword', 'user') + + class UserKeyword(_base.ComparableEntity): + def __init__(self, user=None, keyword=None): + self.user = user + self.keyword = keyword + + @classmethod + @testing.resolve_artifact_names + def setup_mappers(cls): + + mapper(User, users) + mapper(Keyword, keywords, properties={ + 'user_keyword': relation(UserKeyword, uselist=False) + }) + mapper(UserKeyword, userkeywords, properties={ + 'user': relation(User, backref='user_keywords'), + 'keyword': relation(Keyword), + }) + + @classmethod + @testing.resolve_artifact_names + def insert_data(cls): + session = sessionmaker()() + words = ('quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy') + for ii in range(4): + user = User('user%d' % ii) + session.add(user) + for jj in words[ii:ii+3]: + user.keywords.append(Keyword(jj)) + + orphan = Keyword('orphan') + orphan.user_keyword = UserKeyword(keyword=orphan, user=None) + session.add(orphan) + session.commit() + + cls.u = user + cls.kw = user.keywords[0] + cls.session = session + + def _equivalent(self, q_proxy, q_direct): + eq_(q_proxy.all(), q_direct.all()) + + @testing.resolve_artifact_names + def test_filter_any_kwarg(self): + self._equivalent( + self.session.query(User).\ + filter(User.keywords.any(keyword='jumped')), + self.session.query(User).\ + filter(User.user_keywords.any( + UserKeyword.keyword.has(keyword='jumped')))) + + @testing.resolve_artifact_names + def test_filter_has_kwarg(self): + self._equivalent( + self.session.query(Keyword).\ + filter(Keyword.user.has(name='user2')), + self.session.query(Keyword).\ + filter(Keyword.user_keyword.has( + UserKeyword.user.has(name='user2')))) + + @testing.resolve_artifact_names + def test_filter_any_criterion(self): + self._equivalent( + self.session.query(User).\ + filter(User.keywords.any(Keyword.keyword == 'jumped')), + self.session.query(User).\ + filter(User.user_keywords.any( + UserKeyword.keyword.has(Keyword.keyword == 'jumped')))) + + @testing.resolve_artifact_names + def test_filter_has_criterion(self): + self._equivalent( + self.session.query(Keyword).\ + filter(Keyword.user.has(User.name == 'user2')), + self.session.query(Keyword).\ + filter(Keyword.user_keyword.has( + UserKeyword.user.has(User.name == 'user2')))) + + @testing.resolve_artifact_names + def test_filter_contains(self): + self._equivalent( + self.session.query(User).\ + filter(User.keywords.contains(self.kw)), + self.session.query(User).\ + filter(User.user_keywords.any(keyword=self.kw))) + + @testing.resolve_artifact_names + def test_filter_eq(self): + self._equivalent( + self.session.query(Keyword).\ + filter(Keyword.user == self.u), + self.session.query(Keyword).\ + filter(Keyword.user_keyword.has(user=self.u))) + + @testing.resolve_artifact_names + def test_filter_ne(self): + self._equivalent( + self.session.query(Keyword).\ + filter(Keyword.user != self.u), + self.session.query(Keyword).\ + filter(not_(Keyword.user_keyword.has(user=self.u)))) + + @testing.resolve_artifact_names + def test_filter_eq_null(self): + self._equivalent( + self.session.query(Keyword).\ + filter(Keyword.user == None), + self.session.query(Keyword).\ + filter(Keyword.user_keyword.has(UserKeyword.user == None))) + + @testing.resolve_artifact_names + def test_filter_scalar_contains_fails(self): + assert_raises(exceptions.InvalidRequestError, lambda: Keyword.user.contains(self.u)) + + @testing.resolve_artifact_names + def test_filter_scalar_any_fails(self): + assert_raises(exceptions.InvalidRequestError, lambda: Keyword.user.any(name='user2')) + + @testing.resolve_artifact_names + def test_filter_collection_has_fails(self): + assert_raises(exceptions.InvalidRequestError, lambda: User.keywords.has(keyword='quick')) + + @testing.resolve_artifact_names + def test_filter_collection_eq_fails(self): + assert_raises(exceptions.InvalidRequestError, lambda: User.keywords == self.kw) + + @testing.resolve_artifact_names + def test_filter_collection_ne_fails(self): + assert_raises(exceptions.InvalidRequestError, lambda: User.keywords != self.kw)