]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Association proxy now has correct behavior for
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Feb 2011 01:20:34 +0000 (20:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Feb 2011 01:20:34 +0000 (20:20 -0500)
any(), has(), and contains() when proxying
a many-to-one scalar attribute to a one-to-many
collection (i.e. the reverse of the 'typical'
association proxy use case)  [ticket:2054]

CHANGES
lib/sqlalchemy/ext/associationproxy.py
test/ext/test_associationproxy.py

diff --git a/CHANGES b/CHANGES
index 7a2e961b2f586e37d57929dd37a8736c3d93f412..7ec279ac6911938f6e6388a2e2cb78b6500d1096 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -22,6 +22,13 @@ CHANGES
     metadata.create_all() and metadata.drop_all(), 
     including "checkfirst" logic.  [ticket:2055]
 
+- ext
+  - Association proxy now has correct behavior for
+    any(), has(), and contains() when proxying
+    a many-to-one scalar attribute to a one-to-many
+    collection (i.e. the reverse of the 'typical'
+    association proxy use case)  [ticket:2054]
+
 0.7.0b1
 =======
 - Detailed descriptions of each change below are 
index 969f60326a53922bb2181af2183a64431ffdc685..31bfa90ff72fef768054eab8ea0bab8e6d6ac894 100644 (file)
@@ -23,7 +23,9 @@ from sqlalchemy.sql import not_
 
 
 def association_proxy(target_collection, attr, **kw):
-    """Return a Python property implementing a view of *attr* over a collection.
+    """Return a Python property implementing a view of a target
+    attribute which references an attribute on members of the 
+    target.
 
     Implements a read/write view over an instance's *target_collection*,
     extracting *attr* from each member of the collection.  The property acts
@@ -35,16 +37,19 @@ def association_proxy(target_collection, attr, **kw):
     Unlike the list comprehension, the collection returned by the property is
     always in sync with *target_collection*, and mutations made to either
     collection will be reflected in both.
+    
+    The association proxy also works with scalar attributes, which in
+    turn reference scalar attributes or collections.
 
     Implements a Python property representing a relationship as a collection of
-    simpler values.  The proxied property will mimic the collection type of
+    simpler values, or a scalar value.  The proxied property will mimic the collection type of
     the target (list, dict or set), or, in the case of a one to one relationship,
     a simple scalar value.
 
     :param target_collection: Name of the relationship attribute we'll proxy to,
       usually created with :func:`~sqlalchemy.orm.relationship`.
 
-    :param attr: Attribute on the associated instances we'll proxy for.
+    :param attr: Attribute on the associated instance or instances we'll proxy for.
 
       For example, given a target collection of [obj1, obj2], a list created
       by this proxy property would look like [getattr(obj1, *attr*),
@@ -75,7 +80,7 @@ def association_proxy(target_collection, attr, **kw):
       situation.
 
     :param \*\*kw: Passes along any other keyword arguments to
-      :class:`AssociationProxy`.
+      :class:`.AssociationProxy`.
 
     """
     return AssociationProxy(target_collection, attr, **kw)
@@ -85,7 +90,8 @@ class AssociationProxy(object):
     """A descriptor that presents a read/write view of an object attribute."""
 
     def __init__(self, target_collection, attr, creator=None,
-                 getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
+                 getset_factory=None, proxy_factory=None, 
+                 proxy_bulk_set=None):
         """Arguments are:
 
         target_collection
@@ -137,7 +143,6 @@ class AssociationProxy(object):
         self.proxy_factory = proxy_factory
         self.proxy_bulk_set = proxy_bulk_set
 
-        self.scalar = None
         self.owning_class = None
         self.key = '_%s_%s_%s' % (
             type(self).__name__, target_collection, id(self))
@@ -147,23 +152,28 @@ class AssociationProxy(object):
         return (orm.class_mapper(self.owning_class).
                 get_property(self.target_collection))
 
-    @property
+    @util.memoized_property
     def target_class(self):
         """The class the proxy is attached to."""
         return self._get_property().mapper.class_
 
-    def _target_is_scalar(self):
-        return not self._get_property().uselist
+    @util.memoized_property
+    def scalar(self):
+        scalar = not self._get_property().uselist
+        if scalar:
+            self._initialize_scalar_accessors()
+        return scalar
+
+    @util.memoized_property
+    def _value_is_scalar(self):
+        return not self._get_property().\
+                    mapper.get_property(self.value_attr).uselist
 
     def __get__(self, obj, class_):
         if self.owning_class is None:
             self.owning_class = class_ and class_ or type(obj)
         if obj is None:
             return self
-        elif self.scalar is None:
-            self.scalar = self._target_is_scalar()
-            if self.scalar:
-                self._initialize_scalar_accessors()
 
         if self.scalar:
             return self._scalar_get(getattr(obj, self.target_collection))
@@ -183,10 +193,6 @@ class AssociationProxy(object):
     def __set__(self, obj, values):
         if self.owning_class is None:
             self.owning_class = type(obj)
-        if self.scalar is None:
-            self.scalar = self._target_is_scalar()
-            if self.scalar:
-                self._initialize_scalar_accessors()
 
         if self.scalar:
             creator = self.creator and self.creator or self.target_class
@@ -278,13 +284,35 @@ class AssociationProxy(object):
         return self._get_property().comparator
 
     def any(self, criterion=None, **kwargs):
-        return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+        if self._value_is_scalar:
+            value_expr = getattr(self.target_class, self.value_attr).has(criterion, **kwargs)
+        else:
+            value_expr = getattr(self.target_class, self.value_attr).any(criterion, **kwargs)
+
+        # check _value_is_scalar here, otherwise
+        # we're scalar->scalar - call .any() so that
+        # the "can't call any() on a scalar" msg is raised.
+        if self.scalar and not self._value_is_scalar:
+            return self._comparator.has(
+                    value_expr
+                )
+        else:
+            return self._comparator.any(
+                    value_expr
+                )
 
     def has(self, criterion=None, **kwargs):
-        return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs))
+        return self._comparator.has(
+                    getattr(self.target_class, self.value_attr).has(criterion, **kwargs)
+                )
 
     def contains(self, obj):
-        return self._comparator.any(**{self.value_attr: obj})
+        if self.scalar and not self._value_is_scalar:
+            return self._comparator.has(
+                getattr(self.target_class, self.value_attr).contains(obj)
+            )
+        else:
+            return self._comparator.any(**{self.value_attr: obj})
 
     def __eq__(self, obj):
         return self._comparator.has(**{self.value_attr: obj})
index c319863c690d7b0ecfba461af91064ce555133f4..186c75a69ef524d2027ec871e4762d9141a639dc 100644 (file)
@@ -43,20 +43,6 @@ class ObjectCollection(object):
         return iter(self.values)
 
 
-class Parent(object):
-    kids = association_proxy('children', 'name')
-    def __init__(self, name):
-        self.name = name
-
-class Child(object):
-    def __init__(self, name):
-        self.name = name
-
-class KVChild(object):
-    def __init__(self, name, value):
-        self.name = name
-        self.value = value
-
 class _CollectionOperations(TestBase):
     def setup(self):
         collection_class = self.collection_class
@@ -909,6 +895,19 @@ class LazyLoadTest(TestBase):
         self.assert_(p._children is not None)
 
 
+class Parent(object):
+    def __init__(self, name):
+        self.name = name
+
+class Child(object):
+    def __init__(self, name):
+        self.name = name
+
+class KVChild(object):
+    def __init__(self, name, value):
+        self.name = name
+        self.value = value
+
 class ReconstitutionTest(TestBase):
 
     def setup(self):
@@ -928,6 +927,7 @@ class ReconstitutionTest(TestBase):
         self.metadata = metadata
         self.parents = parents
         self.children = children
+        Parent.kids = association_proxy('children', 'name')
 
     def teardown(self):
         self.metadata.drop_all()
@@ -1015,15 +1015,26 @@ class ComparatorTest(_base.MappedTest):
 
     @classmethod
     def define_tables(cls, metadata):
-        Table('userkeywords', metadata, Column('keyword_id', Integer,
-              ForeignKey('keywords.id'), primary_key=True),
-              Column('user_id', Integer, ForeignKey('users.id')))
-        Table('users', metadata, Column('id', Integer,
+        Table('userkeywords', metadata, 
+          Column('keyword_id', Integer,ForeignKey('keywords.id'), primary_key=True),
+          Column('user_id', Integer, ForeignKey('users.id'))
+        )
+        Table('users', metadata, 
+            Column('id', Integer,
+              primary_key=True, test_needs_autoincrement=True),
+            Column('name', String(64)),
+            Column('singular_id', Integer, ForeignKey('singular.id'))
+        )
+        Table('keywords', metadata, 
+            Column('id', Integer,
               primary_key=True, test_needs_autoincrement=True),
-              Column('name', String(64)))
-        Table('keywords', metadata, Column('id', Integer,
+            Column('keyword', String(64)),
+            Column('singular_id', Integer, ForeignKey('singular.id'))
+        )
+        Table('singular', metadata,
+            Column('id', Integer,
               primary_key=True, test_needs_autoincrement=True),
-              Column('keyword', String(64)))
+        )
 
     @classmethod
     def setup_classes(cls):
@@ -1031,13 +1042,21 @@ class ComparatorTest(_base.MappedTest):
             def __init__(self, name):
                 self.name = name
 
+            # o2m -> m2o
+            # uselist -> nonuselist
             keywords = association_proxy('user_keywords', 'keyword',
                     creator=lambda k: UserKeyword(keyword=k))
 
+            # m2o -> o2m
+            # nonuselist -> uselist
+            singular_keywords = association_proxy('singular', 'keywords')
+
         class Keyword(_base.ComparableEntity):
             def __init__(self, keyword):
                 self.keyword = keyword
 
+            # o2o -> m2o
+            # nonuselist -> nonuselist
             user = association_proxy('user_keyword', 'user')
 
         class UserKeyword(_base.ComparableEntity):
@@ -1045,34 +1064,45 @@ class ComparatorTest(_base.MappedTest):
                 self.user = user
                 self.keyword = keyword
 
+        class Singular(_base.ComparableEntity):
+            def __init__(self, value=None):
+                self.value = value
+
     @classmethod
     @testing.resolve_artifact_names
     def setup_mappers(cls):
-        mapper(User, users)
-        mapper(Keyword, keywords, properties={'user_keyword'
-               : relationship(UserKeyword, uselist=False)})
-        mapper(UserKeyword, userkeywords, properties={'user'
-               : relationship(User, backref='user_keywords'), 'keyword'
-               : relationship(Keyword)})
+        mapper(User, users, properties={
+            'singular':relationship(Singular)
+        })
+        mapper(Keyword, keywords, properties={
+            'user_keyword':relationship(UserKeyword, uselist=False)
+        })
+
+        mapper(UserKeyword, userkeywords, properties={
+            'user' : relationship(User, backref='user_keywords'), 
+            'keyword' : relationship(Keyword)
+        })
+        mapper(Singular, singular, properties={
+            'keywords': relationship(Keyword)
+        })
 
     @classmethod
     @testing.resolve_artifact_names
     def insert_data(cls):
         session = sessionmaker()()
         words = (
-            'quick',
-            'brown',
-            'fox',
-            'jumped',
-            'over',
-            'the',
-            'lazy',
+            'quick', 'brown',
+            'fox', 'jumped', 'over',
+            'the', 'lazy',
             )
         for ii in range(4):
             user = User('user%d' % ii)
+            user.singular = Singular()
             session.add(user)
             for jj in words[ii:ii + 3]:
-                user.keywords.append(Keyword(jj))
+                k = Keyword(jj)
+                user.keywords.append(k)
+                user.singular.keywords.append(k)
         orphan = Keyword('orphan')
         orphan.user_keyword = UserKeyword(keyword=orphan, user=None)
         session.add(orphan)
@@ -1085,7 +1115,7 @@ class ComparatorTest(_base.MappedTest):
         eq_(q_proxy.all(), q_direct.all())
 
     @testing.resolve_artifact_names
-    def test_filter_any_kwarg(self):
+    def test_filter_any_kwarg_ul_nul(self):
         self._equivalent(self.session.query(User).
                     filter(User.keywords.any(keyword='jumped'
                          )),
@@ -1095,7 +1125,7 @@ class ComparatorTest(_base.MappedTest):
                          ))))
 
     @testing.resolve_artifact_names
-    def test_filter_has_kwarg(self):
+    def test_filter_has_kwarg_nul_nul(self):
         self._equivalent(self.session.query(Keyword).
                     filter(Keyword.user.has(name='user2'
                          )),
@@ -1105,7 +1135,20 @@ class ComparatorTest(_base.MappedTest):
                          ))))
 
     @testing.resolve_artifact_names
-    def test_filter_any_criterion(self):
+    def test_filter_has_kwarg_nul_ul(self):
+        self._equivalent(
+            self.session.query(User).\
+                        filter(User.singular_keywords.any(keyword='jumped')),
+            self.session.query(User).\
+                        filter(
+                            User.singular.has(
+                                Singular.keywords.any(keyword='jumped')
+                            )
+                        )
+        )
+
+    @testing.resolve_artifact_names
+    def test_filter_any_criterion_ul_nul(self):
         self._equivalent(self.session.query(User).
                     filter(User.keywords.any(Keyword.keyword
                          == 'jumped')),
@@ -1115,7 +1158,7 @@ class ComparatorTest(_base.MappedTest):
                          == 'jumped'))))
 
     @testing.resolve_artifact_names
-    def test_filter_has_criterion(self):
+    def test_filter_has_criterion_nul_nul(self):
         self._equivalent(self.session.query(Keyword).
                 filter(Keyword.user.has(User.name
                          == 'user2')),
@@ -1125,28 +1168,54 @@ class ComparatorTest(_base.MappedTest):
                          == 'user2'))))
 
     @testing.resolve_artifact_names
-    def test_filter_contains(self):
+    def test_filter_any_criterion_nul_ul(self):
+        self._equivalent(
+            self.session.query(User).\
+                        filter(User.singular_keywords.any(Keyword.keyword=='jumped')),
+            self.session.query(User).\
+                        filter(
+                            User.singular.has(
+                                Singular.keywords.any(Keyword.keyword=='jumped')
+                            )
+                        )
+        )
+
+    @testing.resolve_artifact_names
+    def test_filter_contains_ul_nul(self):
         self._equivalent(self.session.query(User).
         filter(User.keywords.contains(self.kw)),
                          self.session.query(User).
                          filter(User.user_keywords.any(keyword=self.kw)))
 
     @testing.resolve_artifact_names
-    def test_filter_eq(self):
+    def test_filter_contains_nul_ul(self):
+        self._equivalent(
+            self.session.query(User).filter(
+                            User.singular_keywords.contains(self.kw)
+            ),
+            self.session.query(User).filter(
+                            User.singular.has(
+                                Singular.keywords.contains(self.kw)
+                            )
+            ),
+        )
+
+    @testing.resolve_artifact_names
+    def test_filter_eq_nul_nul(self):
         self._equivalent(self.session.query(Keyword).filter(Keyword.user
                          == self.u),
                          self.session.query(Keyword).
                          filter(Keyword.user_keyword.has(user=self.u)))
 
     @testing.resolve_artifact_names
-    def test_filter_ne(self):
+    def test_filter_ne_nul_nul(self):
         self._equivalent(self.session.query(Keyword).filter(Keyword.user
                          != self.u),
                          self.session.query(Keyword).
                          filter(not_(Keyword.user_keyword.has(user=self.u))))
 
     @testing.resolve_artifact_names
-    def test_filter_eq_null(self):
+    def test_filter_eq_null_nul_nul(self):
         self._equivalent(self.session.query(Keyword).filter(Keyword.user
                          == None),
                          self.session.query(Keyword).
@@ -1154,26 +1223,26 @@ class ComparatorTest(_base.MappedTest):
                          == None)))
 
     @testing.resolve_artifact_names
-    def test_filter_scalar_contains_fails(self):
+    def test_filter_scalar_contains_fails_nul_nul(self):
         assert_raises(exceptions.InvalidRequestError, lambda : \
                       Keyword.user.contains(self.u))
 
     @testing.resolve_artifact_names
-    def test_filter_scalar_any_fails(self):
+    def test_filter_scalar_any_fails_nul_nul(self):
         assert_raises(exceptions.InvalidRequestError, lambda : \
                       Keyword.user.any(name='user2'))
 
     @testing.resolve_artifact_names
-    def test_filter_collection_has_fails(self):
+    def test_filter_collection_has_fails_ul_nul(self):
         assert_raises(exceptions.InvalidRequestError, lambda : \
                       User.keywords.has(keyword='quick'))
 
     @testing.resolve_artifact_names
-    def test_filter_collection_eq_fails(self):
+    def test_filter_collection_eq_fails_ul_nul(self):
         assert_raises(exceptions.InvalidRequestError, lambda : \
                       User.keywords == self.kw)
 
     @testing.resolve_artifact_names
-    def test_filter_collection_ne_fails(self):
+    def test_filter_collection_ne_fails_ul_nul(self):
         assert_raises(exceptions.InvalidRequestError, lambda : \
                       User.keywords != self.kw)