From 50e3847f09580d1e322fb11f54983e9a31846f19 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 2 Dec 2013 12:40:50 -0500 Subject: [PATCH] - Added new argument ``include_backrefs=True`` to the :func:`.validates` function; when set to False, a validation event will not be triggered if the event was initated as a backref to an attribute operation from the other side. [ticket:1535] - break out validation tests into an updated module test_validators --- doc/build/changelog/changelog_09.rst | 13 ++ doc/build/changelog/migration_09.rst | 43 ++++ doc/build/orm/mapper_config.rst | 47 ++++- lib/sqlalchemy/orm/mapper.py | 22 ++- lib/sqlalchemy/orm/strategies.py | 4 +- lib/sqlalchemy/orm/util.py | 31 ++- test/orm/test_mapper.py | 133 ------------- test/orm/test_validators.py | 281 +++++++++++++++++++++++++++ 8 files changed, 428 insertions(+), 146 deletions(-) create mode 100644 test/orm/test_validators.py diff --git a/doc/build/changelog/changelog_09.rst b/doc/build/changelog/changelog_09.rst index 1f72d0f26c..17116c2c4a 100644 --- a/doc/build/changelog/changelog_09.rst +++ b/doc/build/changelog/changelog_09.rst @@ -14,6 +14,19 @@ .. changelog:: :version: 0.9.0b2 + .. change:: + :tags: feature, orm, backrefs + :tickets: 1535 + + Added new argument ``include_backrefs=True`` to the + :func:`.validates` function; when set to False, a validation event + will not be triggered if the event was initated as a backref to + an attribute operation from the other side. + + .. seealso:: + + :ref:`feature_1535` + .. change:: :tags: bug, orm, collections, py3k :pullreq: github:40 diff --git a/doc/build/changelog/migration_09.rst b/doc/build/changelog/migration_09.rst index 1b0a4fc25b..213385f595 100644 --- a/doc/build/changelog/migration_09.rst +++ b/doc/build/changelog/migration_09.rst @@ -1086,6 +1086,49 @@ the ORM's versioning feature. :ticket:`2793` +.. _feature_1535: + +``include_backrefs=False`` option for ``@validates`` +--------------------------------------------------- + +The :func:`.validates` function now accepts an option ``enable_backrefs=False``, +which will bypass firing the validator for the case where the event initiated +from a backref:: + + from sqlalchemy import Column, Integer, ForeignKey + from sqlalchemy.orm import relationship, validates + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class A(Base): + __tablename__ = 'a' + + id = Column(Integer, primary_key=True) + bs = relationship("B", backref="a") + + @validates("bs") + def validate_bs(self, key, item): + print("A.bs validator") + return item + + class B(Base): + __tablename__ = 'b' + + id = Column(Integer, primary_key=True) + a_id = Column(Integer, ForeignKey('a.id')) + + @validates("a", include_backrefs=False) + def validate_a(self, key, item): + print("B.a validator") + return item + + a1 = A() + a1.bs.append(B()) # prints only "A.bs validator" + + +:ticket:`1535` + Behavioral Improvements ======================= diff --git a/doc/build/orm/mapper_config.rst b/doc/build/orm/mapper_config.rst index 37d9a33e6d..17bd31a6f4 100644 --- a/doc/build/orm/mapper_config.rst +++ b/doc/build/orm/mapper_config.rst @@ -667,7 +667,7 @@ issued when the ORM is populating the object:: assert '@' in address return address -Validators also receive collection events, when items are added to a +Validators also receive collection append events, when items are added to a collection:: from sqlalchemy.orm import validates @@ -682,6 +682,51 @@ collection:: assert '@' in address.email return address + +The validation function by default does not get emitted for collection +remove events, as the typical expectation is that a value being discarded +doesn't require validation. However, :func:`.validates` supports reception +of these events by specifying ``include_removes=True`` to the decorator. When +this flag is set, the validation function must receive an additional boolean +argument which if ``True`` indicates that the operation is a removal:: + + from sqlalchemy.orm import validates + + class User(Base): + # ... + + addresses = relationship("Address") + + @validates('addresses', include_removes=True) + def validate_address(self, key, address, is_remove): + if is_remove: + raise ValueError( + "not allowed to remove items from the collection") + else: + assert '@' in address.email + return address + +The case where mutually dependent validators are linked via a backref +can also be tailored, using the ``include_backrefs=False`` option; this option, +when set to ``False``, prevents a validation function from emitting if the +event occurs as a result of a backref:: + + from sqlalchemy.orm import validates + + class User(Base): + # ... + + addresses = relationship("Address", backref='user') + + @validates('addresses', include_backrefs=False) + def validate_address(self, key, address): + assert '@' in address.email + return address + +Above, if we were to assign to ``Address.user`` as in ``some_address.user = some_user``, +the ``validate_address()`` function would *not* be emitted, even though an append +occurs to ``some_user.addresses`` - the event is caused by a backref. + Note that the :func:`~.validates` decorator is a convenience function built on top of attribute events. An application that requires more control over configuration of attribute change behavior can make use of this system, diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 375e7b1afb..9b91d06383 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1103,11 +1103,10 @@ class Mapper(_InspectionAttr): self._reconstructor = method event.listen(manager, 'load', _event_on_load, raw=True) elif hasattr(method, '__sa_validators__'): - include_removes = getattr(method, - "__sa_include_removes__", False) + validation_opts = method.__sa_validation_opts__ for name in method.__sa_validators__: self.validators = self.validators.union( - {name: (method, include_removes)} + {name: (method, validation_opts)} ) manager.info[_INSTRUMENTOR] = self @@ -2582,13 +2581,28 @@ def validates(*names, **kw): argument "is_remove" which will be a boolean. .. versionadded:: 0.7.7 + :param include_backrefs: defaults to ``True``; if ``False``, the + validation function will not emit if the originator is an attribute + event related via a backref. This can be used for bi-directional + :func:`.validates` usage where only one validator should emit per + attribute operation. + + .. versionadded:: 0.9.0b2 + + .. seealso:: + + :ref:`simple_validators` - usage examples for :func:`.validates` """ include_removes = kw.pop('include_removes', False) + include_backrefs = kw.pop('include_backrefs', True) def wrap(fn): fn.__sa_validators__ = names - fn.__sa_include_removes__ = include_removes + fn.__sa_validation_opts__ = { + "include_removes": include_removes, + "include_backrefs": include_backrefs + } return fn return wrap diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index b04338d9c6..8226a0e0f2 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -44,10 +44,10 @@ def _register_attribute(strategy, mapper, useobject, listen_hooks.append(single_parent_validator) if prop.key in prop.parent.validators: - fn, include_removes = prop.parent.validators[prop.key] + fn, opts = prop.parent.validators[prop.key] listen_hooks.append( lambda desc, prop: orm_util._validator_events(desc, - prop.key, fn, include_removes) + prop.key, fn, **opts) ) if useobject: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 1b8f53c9d7..b866721757 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -70,24 +70,43 @@ class CascadeOptions(frozenset): ) -def _validator_events(desc, key, validator, include_removes): +def _validator_events(desc, key, validator, include_removes, include_backrefs): """Runs a validation method on an attribute value to be set or appended.""" + if not include_backrefs: + def detect_is_backref(state, initiator): + impl = state.manager[key].impl + return initiator.impl is not impl + if include_removes: def append(state, value, initiator): - return validator(state.obj(), key, value, False) + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value def set_(state, value, oldvalue, initiator): - return validator(state.obj(), key, value, False) + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value def remove(state, value, initiator): - validator(state.obj(), key, value, True) + if include_backrefs or not detect_is_backref(state, initiator): + validator(state.obj(), key, value, True) + else: def append(state, value, initiator): - return validator(state.obj(), key, value) + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value def set_(state, value, oldvalue, initiator): - return validator(state.obj(), key, value) + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value event.listen(desc, 'append', append, raw=True, retval=True) event.listen(desc, 'set', set_, raw=True, retval=True) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index b1c9d3fb69..a3222bc8fd 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2035,139 +2035,6 @@ class DeepOptionsTest(_fixtures.FixtureTest): x = u[0].orders[1].items[0].keywords[1] self.sql_count_(2, go) -class ValidatorTest(_fixtures.FixtureTest): - def test_scalar(self): - users = self.tables.users - canary = [] - class User(fixtures.ComparableEntity): - @validates('name') - def validate_name(self, key, name): - canary.append((key, name)) - assert name != 'fred' - return name + ' modified' - - mapper(User, users) - sess = create_session() - u1 = User(name='ed') - eq_(u1.name, 'ed modified') - assert_raises(AssertionError, setattr, u1, "name", "fred") - eq_(u1.name, 'ed modified') - eq_(canary, [('name', 'ed'), ('name', 'fred')]) - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_(sess.query(User).filter_by(name='ed modified').one(), User(name='ed')) - - def test_collection(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) - - canary = [] - class User(fixtures.ComparableEntity): - @validates('addresses') - def validate_address(self, key, ad): - canary.append((key, ad)) - assert '@' in ad.email_address - return ad - - mapper(User, users, properties={'addresses':relationship(Address)}) - mapper(Address, addresses) - sess = create_session() - u1 = User(name='edward') - a0 = Address(email_address='noemail') - assert_raises(AssertionError, u1.addresses.append, a0) - a1 = Address(id=15, email_address='foo@bar.com') - u1.addresses.append(a1) - eq_(canary, [('addresses', a0), ('addresses', a1)]) - sess.add(u1) - sess.flush() - sess.expunge_all() - eq_( - sess.query(User).filter_by(name='edward').one(), - User(name='edward', addresses=[Address(email_address='foo@bar.com')]) - ) - - def test_validators_dict(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) - - class User(fixtures.ComparableEntity): - - @validates('name') - def validate_name(self, key, name): - assert name != 'fred' - return name + ' modified' - - @validates('addresses') - def validate_address(self, key, ad): - assert '@' in ad.email_address - return ad - - def simple_function(self, key, value): - return key, value - - u_m = mapper(User, - users, - properties={'addresses':relationship(Address)}) - mapper(Address, addresses) - - eq_( - dict((k, v[0].__name__) for k, v in list(u_m.validators.items())), - {'name':'validate_name', - 'addresses':'validate_address'} - ) - - def test_validator_w_removes(self): - users, addresses, Address = (self.tables.users, - self.tables.addresses, - self.classes.Address) - canary = [] - class User(fixtures.ComparableEntity): - - @validates('name', include_removes=True) - def validate_name(self, key, item, remove): - canary.append((key, item, remove)) - return item - - @validates('addresses', include_removes=True) - def validate_address(self, key, item, remove): - canary.append((key, item, remove)) - return item - - mapper(User, - users, - properties={'addresses':relationship(Address)}) - mapper(Address, addresses) - - u1 = User() - u1.name = "ed" - u1.name = "mary" - del u1.name - a1, a2, a3 = Address(), Address(), Address() - u1.addresses.append(a1) - u1.addresses.remove(a1) - u1.addresses = [a1, a2] - u1.addresses = [a2, a3] - - eq_(canary, [ - ('name', 'ed', False), - ('name', 'mary', False), - ('name', 'mary', True), - # append a1 - ('addresses', a1, False), - # remove a1 - ('addresses', a1, True), - # set to [a1, a2] - this is two appends - ('addresses', a1, False), ('addresses', a2, False), - # set to [a2, a3] - this is a remove of a1, - # append of a3. the appends are first. - ('addresses', a3, False), - ('addresses', a1, True), - ] - ) - class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL): def test_kwarg_accepted(self): users, Address = self.tables.users, self.classes.Address diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py new file mode 100644 index 0000000000..417554f468 --- /dev/null +++ b/test/orm/test_validators.py @@ -0,0 +1,281 @@ +from test.orm import _fixtures +from sqlalchemy.testing import fixtures, assert_raises, eq_, ne_ +from sqlalchemy.orm import mapper, Session, validates, relationship +from sqlalchemy.testing.mock import Mock, call + + +class ValidatorTest(_fixtures.FixtureTest): + def test_scalar(self): + users = self.tables.users + canary = Mock() + + class User(fixtures.ComparableEntity): + @validates('name') + def validate_name(self, key, name): + canary(key, name) + ne_(name, 'fred') + return name + ' modified' + + mapper(User, users) + sess = Session() + u1 = User(name='ed') + eq_(u1.name, 'ed modified') + assert_raises(AssertionError, setattr, u1, "name", "fred") + eq_(u1.name, 'ed modified') + eq_(canary.mock_calls, [call('name', 'ed'), call('name', 'fred')]) + + sess.add(u1) + sess.commit() + + eq_( + sess.query(User).filter_by(name='ed modified').one(), + User(name='ed') + ) + + def test_collection(self): + users, addresses, Address = (self.tables.users, + self.tables.addresses, + self.classes.Address) + + canary = Mock() + class User(fixtures.ComparableEntity): + @validates('addresses') + def validate_address(self, key, ad): + canary(key, ad) + assert '@' in ad.email_address + return ad + + mapper(User, users, properties={ + 'addresses': relationship(Address)} + ) + mapper(Address, addresses) + sess = Session() + u1 = User(name='edward') + a0 = Address(email_address='noemail') + assert_raises(AssertionError, u1.addresses.append, a0) + a1 = Address(id=15, email_address='foo@bar.com') + u1.addresses.append(a1) + eq_(canary.mock_calls, [call('addresses', a0), call('addresses', a1)]) + sess.add(u1) + sess.commit() + + eq_( + sess.query(User).filter_by(name='edward').one(), + User(name='edward', addresses=[Address(email_address='foo@bar.com')]) + ) + + def test_validators_dict(self): + users, addresses, Address = (self.tables.users, + self.tables.addresses, + self.classes.Address) + + class User(fixtures.ComparableEntity): + + @validates('name') + def validate_name(self, key, name): + ne_(name, 'fred') + return name + ' modified' + + @validates('addresses') + def validate_address(self, key, ad): + assert '@' in ad.email_address + return ad + + def simple_function(self, key, value): + return key, value + + u_m = mapper(User, users, properties={ + 'addresses': relationship(Address) + } + ) + mapper(Address, addresses) + + eq_( + dict((k, v[0].__name__) for k, v in list(u_m.validators.items())), + {'name': 'validate_name', + 'addresses': 'validate_address'} + ) + + def test_validator_w_removes(self): + users, addresses, Address = (self.tables.users, + self.tables.addresses, + self.classes.Address) + canary = Mock() + class User(fixtures.ComparableEntity): + + @validates('name', include_removes=True) + def validate_name(self, key, item, remove): + canary(key, item, remove) + return item + + @validates('addresses', include_removes=True) + def validate_address(self, key, item, remove): + canary(key, item, remove) + return item + + mapper(User, users, properties={ + 'addresses': relationship(Address) + }) + mapper(Address, addresses) + + u1 = User() + u1.name = "ed" + u1.name = "mary" + del u1.name + a1, a2, a3 = Address(), Address(), Address() + u1.addresses.append(a1) + u1.addresses.remove(a1) + u1.addresses = [a1, a2] + u1.addresses = [a2, a3] + + eq_(canary.mock_calls, [ + call('name', 'ed', False), + call('name', 'mary', False), + call('name', 'mary', True), + # append a1 + call('addresses', a1, False), + # remove a1 + call('addresses', a1, True), + # set to [a1, a2] - this is two appends + call('addresses', a1, False), call('addresses', a2, False), + # set to [a2, a3] - this is a remove of a1, + # append of a3. the appends are first. + call('addresses', a3, False), + call('addresses', a1, True), + ] + ) + + def test_validator_wo_backrefs_wo_removes(self): + self._test_validator_backrefs(False, False) + + def test_validator_wo_backrefs_w_removes(self): + self._test_validator_backrefs(False, True) + + def test_validator_w_backrefs_wo_removes(self): + self._test_validator_backrefs(True, False) + + def test_validator_w_backrefs_w_removes(self): + self._test_validator_backrefs(True, True) + + def _test_validator_backrefs(self, include_backrefs, include_removes): + users, addresses = (self.tables.users, + self.tables.addresses) + canary = Mock() + class User(fixtures.ComparableEntity): + + if include_removes: + @validates('addresses', include_removes=True, + include_backrefs=include_backrefs) + def validate_address(self, key, item, remove): + canary(key, item, remove) + return item + else: + @validates('addresses', include_removes=False, + include_backrefs=include_backrefs) + def validate_address(self, key, item): + canary(key, item) + return item + + class Address(fixtures.ComparableEntity): + if include_removes: + @validates('user', include_backrefs=include_backrefs, + include_removes=True) + def validate_user(self, key, item, remove): + canary(key, item, remove) + return item + else: + @validates('user', include_backrefs=include_backrefs) + def validate_user(self, key, item): + canary(key, item) + return item + + mapper(User, users, properties={ + 'addresses': relationship(Address, backref="user") + }) + mapper(Address, addresses) + + u1 = User() + u2 = User() + a1, a2 = Address(), Address() + + # 3 append/set, two removes + u1.addresses.append(a1) + u1.addresses.append(a2) + a2.user = u2 + del a1.user + u2.addresses.remove(a2) + + # copy, so that generation of the + # comparisons don't get caught + calls = list(canary.mock_calls) + + if include_backrefs: + if include_removes: + eq_(calls, + [ + # append #1 + call('addresses', Address(), False), + + # backref for append + call('user', User(addresses=[]), False), + + # append #2 + call('addresses', Address(user=None), False), + + # backref for append + call('user', User(addresses=[]), False), + + # assign a2.user = u2 + call('user', User(addresses=[]), False), + + # backref for u1.addresses.remove(a2) + call('addresses', Address(user=None), True), + + # backref for u2.addresses.append(a2) + call('addresses', Address(user=None), False), + + # del a1.user + call('user', User(addresses=[]), True), + + # backref for u1.addresses.remove(a1) + call('addresses', Address(), True), + + # u2.addresses.remove(a2) + call('addresses', Address(user=None), True), + + # backref for a2.user = None + call('user', None, False) + ] + ) + else: + eq_(calls, + [ + call('addresses', Address()), + call('user', User(addresses=[])), + call('addresses', Address(user=None)), + call('user', User(addresses=[])), + call('user', User(addresses=[])), + call('addresses', Address(user=None)), + call('user', None) + ] + ) + else: + if include_removes: + eq_(calls, + [ + call('addresses', Address(), False), + call('addresses', Address(user=None), False), + call('user', User(addresses=[]), False), + call('user', User(addresses=[]), True), + call('addresses', Address(user=None), True) + ] + + ) + else: + eq_(calls, + [ + call('addresses', Address()), + call('addresses', Address(user=None)), + call('user', User(addresses=[])) + ] + ) -- 2.47.3