From 747f5edd2192caae4ab79ff7dc9d79045b5bad0c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 11 Apr 2012 13:03:52 -0400 Subject: [PATCH] - [feature] Added new flag to @validates include_removes. When True, collection remove and attribute del events will also be sent to the validation function, which accepts an additional argument "is_remove" when this flag is used. --- CHANGES | 7 ++++ lib/sqlalchemy/orm/mapper.py | 12 +++++- lib/sqlalchemy/orm/strategies.py | 6 +-- lib/sqlalchemy/orm/util.py | 52 ++++++++++++++++---------- test/orm/test_mapper.py | 64 ++++++++++++++++++++++++++++++-- 5 files changed, 112 insertions(+), 29 deletions(-) diff --git a/CHANGES b/CHANGES index f4d46f6732..161668e82c 100644 --- a/CHANGES +++ b/CHANGES @@ -12,6 +12,13 @@ CHANGES directives in statements. Courtesy Diana Clarke [ticket:2443] + - [feature] Added new flag to @validates + include_removes. When True, collection + remove and attribute del events + will also be sent to the validation function, + which accepts an additional argument + "is_remove" when this flag is used. + - [bug] Fixed bug whereby polymorphic_on column that's not otherwise mapped on the class would be incorrectly included diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e96b7549a9..afabac05a7 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -678,9 +678,10 @@ class Mapper(object): self._reconstructor = method event.listen(manager, 'load', _event_on_load, raw=True) elif hasattr(method, '__sa_validators__'): + include_removes = getattr(method, "__sa_include_removes__", False) for name in method.__sa_validators__: self.validators = self.validators.union( - {name : method} + {name : (method, include_removes)} ) manager.info[_INSTRUMENTOR] = self @@ -2291,7 +2292,7 @@ def reconstructor(fn): fn.__sa_reconstructor__ = True return fn -def validates(*names): +def validates(*names, **kw): """Decorate a method as a 'validator' for one or more named properties. Designates a method as a validator, a method which receives the @@ -2307,9 +2308,16 @@ def validates(*names): an assertion to avoid recursion overflows. This is a reentrant condition which is not supported. + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. New in 0.7.7. + """ + include_removes = kw.pop('include_removes', False) def wrap(fn): fn.__sa_validators__ = names + fn.__sa_include_removes__ = include_removes return fn return wrap diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 5f4b182d08..37980e1114 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -45,11 +45,11 @@ def _register_attribute(strategy, mapper, useobject, listen_hooks.append(single_parent_validator) if prop.key in prop.parent.validators: + fn, include_removes = prop.parent.validators[prop.key] listen_hooks.append( lambda desc, prop: mapperutil._validator_events(desc, - prop.key, - prop.parent.validators[prop.key]) - ) + prop.key, fn, include_removes) + ) if useobject: listen_hooks.append(unitofwork.track_cascade_events) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0c5f203a72..197c0c4c15 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -68,24 +68,36 @@ class CascadeOptions(frozenset): ",".join([x for x in sorted(self)]) ) -def _validator_events(desc, key, validator): +def _validator_events(desc, key, validator, include_removes): """Runs a validation method on an attribute value to be set or appended.""" - def append(state, value, initiator): - return validator(state.obj(), key, value) + if include_removes: + def append(state, value, initiator): + return validator(state.obj(), key, value, False) - def set_(state, value, oldvalue, initiator): - return validator(state.obj(), key, value) + def set_(state, value, oldvalue, initiator): + return validator(state.obj(), key, value, False) + + def remove(state, value, initiator): + validator(state.obj(), key, value, True) + else: + def append(state, value, initiator): + return validator(state.obj(), key, value) + + def set_(state, value, oldvalue, initiator): + return validator(state.obj(), key, value) event.listen(desc, 'append', append, raw=True, retval=True) event.listen(desc, 'set', set_, raw=True, retval=True) + if include_removes: + event.listen(desc, "remove", remove, raw=True, retval=True) def polymorphic_union(table_map, typecolname, aliasname='p_union', cast_nulls=True): """Create a ``UNION`` statement used by a polymorphic mapper. See :ref:`concrete_inheritance` for an example of how this is used. - + :param table_map: mapping of polymorphic identities to :class:`.Table` objects. :param typecolname: string name of a "discriminator" column, which will be @@ -236,7 +248,7 @@ class AliasedClass(object): session.query(User, user_alias).\\ join((user_alias, User.id > user_alias.id)).\\ filter(User.name==user_alias.name) - + The resulting object is an instance of :class:`.AliasedClass`, however it implements a ``__getattribute__()`` scheme which will proxy attribute access to that of the ORM class being aliased. All classmethods @@ -244,7 +256,7 @@ class AliasedClass(object): hybrids created with the :ref:`hybrids_toplevel` extension, which will receive the :class:`.AliasedClass` as the "class" argument when classmethods are called. - + :param cls: ORM mapped entity which will be "wrapped" around an alias. :param alias: a selectable, such as an :func:`.alias` or :func:`.select` construct, which will be rendered in place of the mapped table of the @@ -259,28 +271,28 @@ class AliasedClass(object): otherwise have a column that corresponds to one on the entity. The use case for this is when associating an entity with some derived selectable such as one that uses aggregate functions:: - + class UnitPrice(Base): __tablename__ = 'unit_price' ... unit_id = Column(Integer) price = Column(Numeric) - + aggregated_unit_price = Session.query( func.sum(UnitPrice.price).label('price') ).group_by(UnitPrice.unit_id).subquery() - + aggregated_unit_price = aliased(UnitPrice, alias=aggregated_unit_price, adapt_on_names=True) - + Above, functions on ``aggregated_unit_price`` which refer to ``.price`` will return the ``fund.sum(UnitPrice.price).label('price')`` column, as it is matched on the name "price". Ordinarily, the "price" function wouldn't have any "column correspondence" to the actual ``UnitPrice.price`` column as it is not a proxy of the original. - + ``adapt_on_names`` is new in 0.7.3. - + """ def __init__(self, cls, alias=None, name=None, adapt_on_names=False): self.__mapper = _class_to_mapper(cls) @@ -447,7 +459,7 @@ class _ORMJoin(expression.Join): def join(left, right, onclause=None, isouter=False, join_to_left=True): """Produce an inner join between left and right clauses. - + :func:`.orm.join` is an extension to the core join interface provided by :func:`.sql.expression.join()`, where the left and right selectables may be not only core selectable @@ -460,7 +472,7 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True): in whatever form it is passed, to the selectable passed as the left side. If False, the onclause is used as is. - + :func:`.orm.join` is not commonly needed in modern usage, as its functionality is encapsulated within that of the :meth:`.Query.join` method, which features a @@ -468,22 +480,22 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True): by itself. Explicit usage of :func:`.orm.join` with :class:`.Query` involves usage of the :meth:`.Query.select_from` method, as in:: - + from sqlalchemy.orm import join session.query(User).\\ select_from(join(User, Address, User.addresses)).\\ filter(Address.email_address=='foo@bar.com') - + In modern SQLAlchemy the above join can be written more succinctly as:: - + session.query(User).\\ join(User.addresses).\\ filter(Address.email_address=='foo@bar.com') See :meth:`.Query.join` for information on modern usage of ORM level joins. - + """ return _ORMJoin(left, right, onclause, isouter, join_to_left) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index 1c5f29b716..79ae7ff590 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -1950,10 +1950,11 @@ class DeepOptionsTest(_fixtures.FixtureTest): class ValidatorTest(_fixtures.FixtureTest): def test_scalar(self): users = self.tables.users - + canary = [] class User(fixtures.ComparableEntity): @validates('name') def validate_name(self, key, name): + canary.append((key, name)) assert name != 'fred' return name + ' modified' @@ -1963,6 +1964,7 @@ class ValidatorTest(_fixtures.FixtureTest): eq_(u1.name, 'ed modified') assert_raises(AssertionError, setattr, u1, "name", "fred") eq_(u1.name, 'ed modified') + eq_(canary, [('name', 'ed'), ('name', 'fred')]) sess.add(u1) sess.flush() sess.expunge_all() @@ -1973,9 +1975,11 @@ class ValidatorTest(_fixtures.FixtureTest): self.tables.addresses, self.classes.Address) + canary = [] class User(fixtures.ComparableEntity): @validates('addresses') def validate_address(self, key, ad): + canary.append((key, ad)) assert '@' in ad.email_address return ad @@ -1983,8 +1987,11 @@ class ValidatorTest(_fixtures.FixtureTest): mapper(Address, addresses) sess = create_session() u1 = User(name='edward') - assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail')) - u1.addresses.append(Address(id=15, email_address='foo@bar.com')) + a0 = Address(email_address='noemail') + assert_raises(AssertionError, u1.addresses.append, a0) + a1 = Address(id=15, email_address='foo@bar.com') + u1.addresses.append(a1) + eq_(canary, [('addresses', a0), ('addresses', a1)]) sess.add(u1) sess.flush() sess.expunge_all() @@ -2019,11 +2026,60 @@ class ValidatorTest(_fixtures.FixtureTest): mapper(Address, addresses) eq_( - dict((k, v.__name__) for k, v in u_m.validators.items()), + dict((k, v[0].__name__) for k, v in u_m.validators.items()), {'name':'validate_name', 'addresses':'validate_address'} ) + def test_validator_w_removes(self): + users, addresses, Address = (self.tables.users, + self.tables.addresses, + self.classes.Address) + canary = [] + class User(fixtures.ComparableEntity): + + @validates('name', include_removes=True) + def validate_name(self, key, item, remove): + canary.append((key, item, remove)) + return item + + @validates('addresses', include_removes=True) + def validate_address(self, key, item, remove): + canary.append((key, item, remove)) + return item + + mapper(User, + users, + properties={'addresses':relationship(Address)}) + mapper(Address, addresses) + + u1 = User() + u1.name = "ed" + u1.name = "mary" + del u1.name + a1, a2, a3 = Address(), Address(), Address() + u1.addresses.append(a1) + u1.addresses.remove(a1) + u1.addresses = [a1, a2] + u1.addresses = [a2, a3] + + eq_(canary, [ + ('name', 'ed', False), + ('name', 'mary', False), + ('name', 'mary', True), + # append a1 + ('addresses', a1, False), + # remove a1 + ('addresses', a1, True), + # set to [a1, a2] - this is two appends + ('addresses', a1, False), ('addresses', a2, False), + # set to [a2, a3] - this is a remove of a1, + # append of a3. the appends are first. + ('addresses', a3, False), + ('addresses', a1, True), + ] + ) + class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_kwarg_accepted(self): users, Address = self.tables.users, self.classes.Address -- 2.47.2