From: Mike Bayer Date: Mon, 14 Feb 2011 01:20:34 +0000 (-0500) Subject: - Association proxy now has correct behavior for X-Git-Tag: rel_0_7b2~1^2~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2722035809364af9d6ea533241d34935ca17e6af;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Association proxy now has correct behavior for any(), has(), and contains() when proxying a many-to-one scalar attribute to a one-to-many collection (i.e. the reverse of the 'typical' association proxy use case) [ticket:2054] --- diff --git a/CHANGES b/CHANGES index 7a2e961b2f..7ec279ac69 100644 --- a/CHANGES +++ b/CHANGES @@ -22,6 +22,13 @@ CHANGES metadata.create_all() and metadata.drop_all(), including "checkfirst" logic. [ticket:2055] +- ext + - Association proxy now has correct behavior for + any(), has(), and contains() when proxying + a many-to-one scalar attribute to a one-to-many + collection (i.e. the reverse of the 'typical' + association proxy use case) [ticket:2054] + 0.7.0b1 ======= - Detailed descriptions of each change below are diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 969f60326a..31bfa90ff7 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -23,7 +23,9 @@ from sqlalchemy.sql import not_ def association_proxy(target_collection, attr, **kw): - """Return a Python property implementing a view of *attr* over a collection. + """Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. Implements a read/write view over an instance's *target_collection*, extracting *attr* from each member of the collection. The property acts @@ -35,16 +37,19 @@ def association_proxy(target_collection, attr, **kw): Unlike the list comprehension, the collection returned by the property is always in sync with *target_collection*, and mutations made to either collection will be reflected in both. + + The association proxy also works with scalar attributes, which in + turn reference scalar attributes or collections. Implements a Python property representing a relationship as a collection of - simpler values. The proxied property will mimic the collection type of + simpler values, or a scalar value. The proxied property will mimic the collection type of the target (list, dict or set), or, in the case of a one to one relationship, a simple scalar value. :param target_collection: Name of the relationship attribute we'll proxy to, usually created with :func:`~sqlalchemy.orm.relationship`. - :param attr: Attribute on the associated instances we'll proxy for. + :param attr: Attribute on the associated instance or instances we'll proxy for. For example, given a target collection of [obj1, obj2], a list created by this proxy property would look like [getattr(obj1, *attr*), @@ -75,7 +80,7 @@ def association_proxy(target_collection, attr, **kw): situation. :param \*\*kw: Passes along any other keyword arguments to - :class:`AssociationProxy`. + :class:`.AssociationProxy`. """ return AssociationProxy(target_collection, attr, **kw) @@ -85,7 +90,8 @@ class AssociationProxy(object): """A descriptor that presents a read/write view of an object attribute.""" def __init__(self, target_collection, attr, creator=None, - getset_factory=None, proxy_factory=None, proxy_bulk_set=None): + getset_factory=None, proxy_factory=None, + proxy_bulk_set=None): """Arguments are: target_collection @@ -137,7 +143,6 @@ class AssociationProxy(object): self.proxy_factory = proxy_factory self.proxy_bulk_set = proxy_bulk_set - self.scalar = None self.owning_class = None self.key = '_%s_%s_%s' % ( type(self).__name__, target_collection, id(self)) @@ -147,23 +152,28 @@ class AssociationProxy(object): return (orm.class_mapper(self.owning_class). get_property(self.target_collection)) - @property + @util.memoized_property def target_class(self): """The class the proxy is attached to.""" return self._get_property().mapper.class_ - def _target_is_scalar(self): - return not self._get_property().uselist + @util.memoized_property + def scalar(self): + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self): + return not self._get_property().\ + mapper.get_property(self.value_attr).uselist def __get__(self, obj, class_): if self.owning_class is None: self.owning_class = class_ and class_ or type(obj) if obj is None: return self - elif self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: return self._scalar_get(getattr(obj, self.target_collection)) @@ -183,10 +193,6 @@ class AssociationProxy(object): def __set__(self, obj, values): if self.owning_class is None: self.owning_class = type(obj) - if self.scalar is None: - self.scalar = self._target_is_scalar() - if self.scalar: - self._initialize_scalar_accessors() if self.scalar: creator = self.creator and self.creator or self.target_class @@ -278,13 +284,35 @@ class AssociationProxy(object): return self._get_property().comparator def any(self, criterion=None, **kwargs): - return self._comparator.any(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + if self._value_is_scalar: + value_expr = getattr(self.target_class, self.value_attr).has(criterion, **kwargs) + else: + value_expr = getattr(self.target_class, self.value_attr).any(criterion, **kwargs) + + # check _value_is_scalar here, otherwise + # we're scalar->scalar - call .any() so that + # the "can't call any() on a scalar" msg is raised. + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + value_expr + ) + else: + return self._comparator.any( + value_expr + ) def has(self, criterion=None, **kwargs): - return self._comparator.has(getattr(self.target_class, self.value_attr).has(criterion, **kwargs)) + return self._comparator.has( + getattr(self.target_class, self.value_attr).has(criterion, **kwargs) + ) def contains(self, obj): - return self._comparator.any(**{self.value_attr: obj}) + if self.scalar and not self._value_is_scalar: + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(obj) + ) + else: + return self._comparator.any(**{self.value_attr: obj}) def __eq__(self, obj): return self._comparator.has(**{self.value_attr: obj}) diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index c319863c69..186c75a69e 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -43,20 +43,6 @@ class ObjectCollection(object): return iter(self.values) -class Parent(object): - kids = association_proxy('children', 'name') - def __init__(self, name): - self.name = name - -class Child(object): - def __init__(self, name): - self.name = name - -class KVChild(object): - def __init__(self, name, value): - self.name = name - self.value = value - class _CollectionOperations(TestBase): def setup(self): collection_class = self.collection_class @@ -909,6 +895,19 @@ class LazyLoadTest(TestBase): self.assert_(p._children is not None) +class Parent(object): + def __init__(self, name): + self.name = name + +class Child(object): + def __init__(self, name): + self.name = name + +class KVChild(object): + def __init__(self, name, value): + self.name = name + self.value = value + class ReconstitutionTest(TestBase): def setup(self): @@ -928,6 +927,7 @@ class ReconstitutionTest(TestBase): self.metadata = metadata self.parents = parents self.children = children + Parent.kids = association_proxy('children', 'name') def teardown(self): self.metadata.drop_all() @@ -1015,15 +1015,26 @@ class ComparatorTest(_base.MappedTest): @classmethod def define_tables(cls, metadata): - Table('userkeywords', metadata, Column('keyword_id', Integer, - ForeignKey('keywords.id'), primary_key=True), - Column('user_id', Integer, ForeignKey('users.id'))) - Table('users', metadata, Column('id', Integer, + Table('userkeywords', metadata, + Column('keyword_id', Integer,ForeignKey('keywords.id'), primary_key=True), + Column('user_id', Integer, ForeignKey('users.id')) + ) + Table('users', metadata, + Column('id', Integer, + primary_key=True, test_needs_autoincrement=True), + Column('name', String(64)), + Column('singular_id', Integer, ForeignKey('singular.id')) + ) + Table('keywords', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('name', String(64))) - Table('keywords', metadata, Column('id', Integer, + Column('keyword', String(64)), + Column('singular_id', Integer, ForeignKey('singular.id')) + ) + Table('singular', metadata, + Column('id', Integer, primary_key=True, test_needs_autoincrement=True), - Column('keyword', String(64))) + ) @classmethod def setup_classes(cls): @@ -1031,13 +1042,21 @@ class ComparatorTest(_base.MappedTest): def __init__(self, name): self.name = name + # o2m -> m2o + # uselist -> nonuselist keywords = association_proxy('user_keywords', 'keyword', creator=lambda k: UserKeyword(keyword=k)) + # m2o -> o2m + # nonuselist -> uselist + singular_keywords = association_proxy('singular', 'keywords') + class Keyword(_base.ComparableEntity): def __init__(self, keyword): self.keyword = keyword + # o2o -> m2o + # nonuselist -> nonuselist user = association_proxy('user_keyword', 'user') class UserKeyword(_base.ComparableEntity): @@ -1045,34 +1064,45 @@ class ComparatorTest(_base.MappedTest): self.user = user self.keyword = keyword + class Singular(_base.ComparableEntity): + def __init__(self, value=None): + self.value = value + @classmethod @testing.resolve_artifact_names def setup_mappers(cls): - mapper(User, users) - mapper(Keyword, keywords, properties={'user_keyword' - : relationship(UserKeyword, uselist=False)}) - mapper(UserKeyword, userkeywords, properties={'user' - : relationship(User, backref='user_keywords'), 'keyword' - : relationship(Keyword)}) + mapper(User, users, properties={ + 'singular':relationship(Singular) + }) + mapper(Keyword, keywords, properties={ + 'user_keyword':relationship(UserKeyword, uselist=False) + }) + + mapper(UserKeyword, userkeywords, properties={ + 'user' : relationship(User, backref='user_keywords'), + 'keyword' : relationship(Keyword) + }) + mapper(Singular, singular, properties={ + 'keywords': relationship(Keyword) + }) @classmethod @testing.resolve_artifact_names def insert_data(cls): session = sessionmaker()() words = ( - 'quick', - 'brown', - 'fox', - 'jumped', - 'over', - 'the', - 'lazy', + 'quick', 'brown', + 'fox', 'jumped', 'over', + 'the', 'lazy', ) for ii in range(4): user = User('user%d' % ii) + user.singular = Singular() session.add(user) for jj in words[ii:ii + 3]: - user.keywords.append(Keyword(jj)) + k = Keyword(jj) + user.keywords.append(k) + user.singular.keywords.append(k) orphan = Keyword('orphan') orphan.user_keyword = UserKeyword(keyword=orphan, user=None) session.add(orphan) @@ -1085,7 +1115,7 @@ class ComparatorTest(_base.MappedTest): eq_(q_proxy.all(), q_direct.all()) @testing.resolve_artifact_names - def test_filter_any_kwarg(self): + def test_filter_any_kwarg_ul_nul(self): self._equivalent(self.session.query(User). filter(User.keywords.any(keyword='jumped' )), @@ -1095,7 +1125,7 @@ class ComparatorTest(_base.MappedTest): )))) @testing.resolve_artifact_names - def test_filter_has_kwarg(self): + def test_filter_has_kwarg_nul_nul(self): self._equivalent(self.session.query(Keyword). filter(Keyword.user.has(name='user2' )), @@ -1105,7 +1135,20 @@ class ComparatorTest(_base.MappedTest): )))) @testing.resolve_artifact_names - def test_filter_any_criterion(self): + def test_filter_has_kwarg_nul_ul(self): + self._equivalent( + self.session.query(User).\ + filter(User.singular_keywords.any(keyword='jumped')), + self.session.query(User).\ + filter( + User.singular.has( + Singular.keywords.any(keyword='jumped') + ) + ) + ) + + @testing.resolve_artifact_names + def test_filter_any_criterion_ul_nul(self): self._equivalent(self.session.query(User). filter(User.keywords.any(Keyword.keyword == 'jumped')), @@ -1115,7 +1158,7 @@ class ComparatorTest(_base.MappedTest): == 'jumped')))) @testing.resolve_artifact_names - def test_filter_has_criterion(self): + def test_filter_has_criterion_nul_nul(self): self._equivalent(self.session.query(Keyword). filter(Keyword.user.has(User.name == 'user2')), @@ -1125,28 +1168,54 @@ class ComparatorTest(_base.MappedTest): == 'user2')))) @testing.resolve_artifact_names - def test_filter_contains(self): + def test_filter_any_criterion_nul_ul(self): + self._equivalent( + self.session.query(User).\ + filter(User.singular_keywords.any(Keyword.keyword=='jumped')), + self.session.query(User).\ + filter( + User.singular.has( + Singular.keywords.any(Keyword.keyword=='jumped') + ) + ) + ) + + @testing.resolve_artifact_names + def test_filter_contains_ul_nul(self): self._equivalent(self.session.query(User). filter(User.keywords.contains(self.kw)), self.session.query(User). filter(User.user_keywords.any(keyword=self.kw))) @testing.resolve_artifact_names - def test_filter_eq(self): + def test_filter_contains_nul_ul(self): + self._equivalent( + self.session.query(User).filter( + User.singular_keywords.contains(self.kw) + ), + self.session.query(User).filter( + User.singular.has( + Singular.keywords.contains(self.kw) + ) + ), + ) + + @testing.resolve_artifact_names + def test_filter_eq_nul_nul(self): self._equivalent(self.session.query(Keyword).filter(Keyword.user == self.u), self.session.query(Keyword). filter(Keyword.user_keyword.has(user=self.u))) @testing.resolve_artifact_names - def test_filter_ne(self): + def test_filter_ne_nul_nul(self): self._equivalent(self.session.query(Keyword).filter(Keyword.user != self.u), self.session.query(Keyword). filter(not_(Keyword.user_keyword.has(user=self.u)))) @testing.resolve_artifact_names - def test_filter_eq_null(self): + def test_filter_eq_null_nul_nul(self): self._equivalent(self.session.query(Keyword).filter(Keyword.user == None), self.session.query(Keyword). @@ -1154,26 +1223,26 @@ class ComparatorTest(_base.MappedTest): == None))) @testing.resolve_artifact_names - def test_filter_scalar_contains_fails(self): + def test_filter_scalar_contains_fails_nul_nul(self): assert_raises(exceptions.InvalidRequestError, lambda : \ Keyword.user.contains(self.u)) @testing.resolve_artifact_names - def test_filter_scalar_any_fails(self): + def test_filter_scalar_any_fails_nul_nul(self): assert_raises(exceptions.InvalidRequestError, lambda : \ Keyword.user.any(name='user2')) @testing.resolve_artifact_names - def test_filter_collection_has_fails(self): + def test_filter_collection_has_fails_ul_nul(self): assert_raises(exceptions.InvalidRequestError, lambda : \ User.keywords.has(keyword='quick')) @testing.resolve_artifact_names - def test_filter_collection_eq_fails(self): + def test_filter_collection_eq_fails_ul_nul(self): assert_raises(exceptions.InvalidRequestError, lambda : \ User.keywords == self.kw) @testing.resolve_artifact_names - def test_filter_collection_ne_fails(self): + def test_filter_collection_ne_fails_ul_nul(self): assert_raises(exceptions.InvalidRequestError, lambda : \ User.keywords != self.kw)