]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- association_proxy now has basic comparator methods .any(),
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Jan 2010 20:24:27 +0000 (20:24 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Jan 2010 20:24:27 +0000 (20:24 +0000)
.has(), .contains(), ==, !=, thanks to Scott Torborg.
[ticket:1372]

CHANGES
lib/sqlalchemy/ext/associationproxy.py
test/ext/test_associationproxy.py

diff --git a/CHANGES b/CHANGES
index 895344dc4581773ffc031aba6d983ac1dc12d6c7..28feff96783fb2dd9c844ac54de8502dc7e30c2d 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -871,6 +871,10 @@ CHANGES
       serializability and subclassing of the built in collections.
       [ticket:1259]
 
+    - association_proxy now has basic comparator methods .any(),
+      .has(), .contains(), ==, !=, thanks to Scott Torborg.
+      [ticket:1372]
+      
 - examples
     - The "query_cache" examples have been removed, and are replaced
       with a fully comprehensive approach that combines the usage of
index 4353558f82afe6513ae7efcf9fec0d982fc2d81b..b63bd9b0064a0beda77372b398f058ebb5b4566c 100644 (file)
@@ -13,6 +13,7 @@ from sqlalchemy import exceptions
 from sqlalchemy import orm
 from sqlalchemy import util
 from sqlalchemy.orm import collections
+from sqlalchemy.sql import not_
 
 
 def association_proxy(target_collection, attr, **kw):
@@ -266,6 +267,26 @@ class AssociationProxy(object):
                'no proxy_bulk_set supplied for custom '
                'collection_class implementation')
 
+    @property
+    def _comparator(self):
+        return self._get_property().comparator
+
+    def any(self, criterion=None, **kwargs):
+        return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+    
+    def has(self, criterion=None, **kwargs):
+        return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+
+    def contains(self, obj):
+        return self._comparator.any(**{self.value_attr: obj})
+
+    def __eq__(self, obj):
+        return self._comparator.has(**{self.value_attr: obj})
+
+    def __ne__(self, obj):
+        return not_(self.__eq__(obj))
+
+
 class _lazy_collection(object):
     def __init__(self, obj, target):
         self.ref = weakref.ref(obj)
index e7a6de5e19c78044ba1cdb7f676333019d26a3aa..81183d14a495f70cc6fa75c3b85db89206e1b9e3 100644 (file)
@@ -9,6 +9,8 @@ from sqlalchemy.ext.associationproxy import *
 from sqlalchemy.ext.associationproxy import _AssociationList
 from sqlalchemy.test import *
 from sqlalchemy.test.util import gc_collect
+from sqlalchemy.sql import not_
+from test.orm import _base
 
 
 class DictCollection(dict):
@@ -997,4 +999,170 @@ class PickleKeyFunc(object):
         self.name = name
     
     def __call__(self, obj):
-        return getattr(obj, self.name)
\ No newline at end of file
+        return getattr(obj, self.name)
+
+class ComparatorTest(_base.MappedTest):
+    run_inserts = 'once'
+    run_deletes = None
+    run_setup_mappers = 'once'
+    
+    @classmethod
+    def define_tables(cls, metadata):
+
+        Table(
+            'userkeywords', metadata,
+            Column('keyword_id', Integer, ForeignKey('keywords.id'),
+                   primary_key=True),
+            Column('user_id', Integer, ForeignKey('users.id')))
+
+        Table(
+            'users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(64)))
+
+        Table(
+            'keywords', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('keyword', String(64)))
+
+    @classmethod
+    def setup_classes(cls):
+        class User(_base.ComparableEntity):
+            def __init__(self, name):
+                self.name = name
+            keywords = association_proxy('user_keywords', 'keyword',
+                                         creator=lambda k: UserKeyword(keyword=k))
+
+        class Keyword(_base.ComparableEntity):
+            def __init__(self, keyword):
+                self.keyword = keyword
+            user = association_proxy('user_keyword', 'user')
+
+        class UserKeyword(_base.ComparableEntity):
+            def __init__(self, user=None, keyword=None):
+                self.user = user
+                self.keyword = keyword
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def setup_mappers(cls):
+        
+        mapper(User, users)
+        mapper(Keyword, keywords, properties={
+            'user_keyword': relation(UserKeyword, uselist=False)
+        })
+        mapper(UserKeyword, userkeywords, properties={
+            'user': relation(User, backref='user_keywords'),
+            'keyword': relation(Keyword),
+        })
+
+    @classmethod
+    @testing.resolve_artifact_names
+    def insert_data(cls):
+        session = sessionmaker()()
+        words = ('quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy')
+        for ii in range(4):
+            user = User('user%d' % ii)
+            session.add(user)
+            for jj in words[ii:ii+3]:
+                user.keywords.append(Keyword(jj))
+
+        orphan = Keyword('orphan')
+        orphan.user_keyword = UserKeyword(keyword=orphan, user=None)
+        session.add(orphan)
+        session.commit()
+
+        cls.u = user
+        cls.kw = user.keywords[0]
+        cls.session = session
+
+    def _equivalent(self, q_proxy, q_direct):
+        eq_(q_proxy.all(), q_direct.all())
+
+    @testing.resolve_artifact_names
+    def test_filter_any_kwarg(self):
+        self._equivalent(
+            self.session.query(User).\
+                filter(User.keywords.any(keyword='jumped')),
+            self.session.query(User).\
+                filter(User.user_keywords.any(
+                    UserKeyword.keyword.has(keyword='jumped'))))
+
+    @testing.resolve_artifact_names
+    def test_filter_has_kwarg(self):
+        self._equivalent(
+            self.session.query(Keyword).\
+                filter(Keyword.user.has(name='user2')),
+            self.session.query(Keyword).\
+                filter(Keyword.user_keyword.has(
+                    UserKeyword.user.has(name='user2'))))
+
+    @testing.resolve_artifact_names
+    def test_filter_any_criterion(self):
+        self._equivalent(
+            self.session.query(User).\
+                filter(User.keywords.any(Keyword.keyword == 'jumped')),
+            self.session.query(User).\
+                filter(User.user_keywords.any(
+                    UserKeyword.keyword.has(Keyword.keyword == 'jumped'))))
+
+    @testing.resolve_artifact_names
+    def test_filter_has_criterion(self):
+        self._equivalent(
+            self.session.query(Keyword).\
+                filter(Keyword.user.has(User.name == 'user2')),
+            self.session.query(Keyword).\
+                filter(Keyword.user_keyword.has(
+                    UserKeyword.user.has(User.name == 'user2'))))
+    
+    @testing.resolve_artifact_names
+    def test_filter_contains(self):
+        self._equivalent(
+            self.session.query(User).\
+                filter(User.keywords.contains(self.kw)),
+            self.session.query(User).\
+                filter(User.user_keywords.any(keyword=self.kw)))
+    
+    @testing.resolve_artifact_names
+    def test_filter_eq(self):
+        self._equivalent(
+            self.session.query(Keyword).\
+                filter(Keyword.user == self.u),
+            self.session.query(Keyword).\
+                filter(Keyword.user_keyword.has(user=self.u)))
+    
+    @testing.resolve_artifact_names
+    def test_filter_ne(self):
+        self._equivalent(
+            self.session.query(Keyword).\
+                filter(Keyword.user != self.u),
+            self.session.query(Keyword).\
+                filter(not_(Keyword.user_keyword.has(user=self.u))))
+
+    @testing.resolve_artifact_names
+    def test_filter_eq_null(self):
+        self._equivalent(
+            self.session.query(Keyword).\
+                filter(Keyword.user == None),
+            self.session.query(Keyword).\
+                filter(Keyword.user_keyword.has(UserKeyword.user == None)))
+
+    @testing.resolve_artifact_names
+    def test_filter_scalar_contains_fails(self):
+        assert_raises(exceptions.InvalidRequestError, lambda: Keyword.user.contains(self.u))
+    
+    @testing.resolve_artifact_names
+    def test_filter_scalar_any_fails(self):
+        assert_raises(exceptions.InvalidRequestError, lambda: Keyword.user.any(name='user2'))
+
+    @testing.resolve_artifact_names
+    def test_filter_collection_has_fails(self):
+        assert_raises(exceptions.InvalidRequestError, lambda: User.keywords.has(keyword='quick'))
+
+    @testing.resolve_artifact_names
+    def test_filter_collection_eq_fails(self):
+        assert_raises(exceptions.InvalidRequestError, lambda: User.keywords == self.kw)
+
+    @testing.resolve_artifact_names
+    def test_filter_collection_ne_fails(self):
+        assert_raises(exceptions.InvalidRequestError, lambda: User.keywords != self.kw)