]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Added new argument ``include_backrefs=True`` to the
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Dec 2013 17:40:50 +0000 (12:40 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 2 Dec 2013 17:40:50 +0000 (12:40 -0500)
:func:`.validates` function; when set to False, a validation event
will not be triggered if the event was initated as a backref to
an attribute operation from the other side. [ticket:1535]
- break out validation tests into an updated module test_validators

doc/build/changelog/changelog_09.rst
doc/build/changelog/migration_09.rst
doc/build/orm/mapper_config.rst
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/test_mapper.py
test/orm/test_validators.py [new file with mode: 0644]

index 1f72d0f26cd48b023dddd208c32a9df999b40ec3..17116c2c4a653bc9337374c3034e3d6347a10d0c 100644 (file)
 .. changelog::
     :version: 0.9.0b2
 
+    .. change::
+        :tags: feature, orm, backrefs
+        :tickets: 1535
+
+        Added new argument ``include_backrefs=True`` to the
+        :func:`.validates` function; when set to False, a validation event
+        will not be triggered if the event was initated as a backref to
+        an attribute operation from the other side.
+
+        .. seealso::
+
+            :ref:`feature_1535`
+
     .. change::
         :tags: bug, orm, collections, py3k
         :pullreq: github:40
index 1b0a4fc25b528dde2a16b0cc9ef6c1d33818c895..213385f595f24e50920d246b647b1b1a449d1327 100644 (file)
@@ -1086,6 +1086,49 @@ the ORM's versioning feature.
 
 :ticket:`2793`
 
+.. _feature_1535:
+
+``include_backrefs=False`` option for ``@validates``
+---------------------------------------------------
+
+The :func:`.validates` function now accepts an option ``enable_backrefs=False``,
+which will bypass firing the validator for the case where the event initiated
+from a backref::
+
+    from sqlalchemy import Column, Integer, ForeignKey
+    from sqlalchemy.orm import relationship, validates
+    from sqlalchemy.ext.declarative import declarative_base
+
+    Base = declarative_base()
+
+    class A(Base):
+        __tablename__ = 'a'
+
+        id = Column(Integer, primary_key=True)
+        bs = relationship("B", backref="a")
+
+        @validates("bs")
+        def validate_bs(self, key, item):
+            print("A.bs validator")
+            return item
+
+    class B(Base):
+        __tablename__ = 'b'
+
+        id = Column(Integer, primary_key=True)
+        a_id = Column(Integer, ForeignKey('a.id'))
+
+        @validates("a", include_backrefs=False)
+        def validate_a(self, key, item):
+            print("B.a validator")
+            return item
+
+    a1 = A()
+    a1.bs.append(B())  # prints only "A.bs validator"
+
+
+:ticket:`1535`
+
 Behavioral Improvements
 =======================
 
index 37d9a33e6d907743285d162cad726d81fc2381a0..17bd31a6f4b262bba73ab6b916e4088dfa6d11be 100644 (file)
@@ -667,7 +667,7 @@ issued when the ORM is populating the object::
             assert '@' in address
             return address
 
-Validators also receive collection events, when items are added to a
+Validators also receive collection append events, when items are added to a
 collection::
 
     from sqlalchemy.orm import validates
@@ -682,6 +682,51 @@ collection::
             assert '@' in address.email
             return address
 
+
+The validation function by default does not get emitted for collection
+remove events, as the typical expectation is that a value being discarded
+doesn't require validation.  However, :func:`.validates` supports reception
+of these events by specifying ``include_removes=True`` to the decorator.  When
+this flag is set, the validation function must receive an additional boolean
+argument which if ``True`` indicates that the operation is a removal::
+
+    from sqlalchemy.orm import validates
+
+    class User(Base):
+        # ...
+
+        addresses = relationship("Address")
+
+        @validates('addresses', include_removes=True)
+        def validate_address(self, key, address, is_remove):
+            if is_remove:
+                raise ValueError(
+                        "not allowed to remove items from the collection")
+            else:
+                assert '@' in address.email
+                return address
+
+The case where mutually dependent validators are linked via a backref
+can also be tailored, using the ``include_backrefs=False`` option; this option,
+when set to ``False``, prevents a validation function from emitting if the
+event occurs as a result of a backref::
+
+    from sqlalchemy.orm import validates
+
+    class User(Base):
+        # ...
+
+        addresses = relationship("Address", backref='user')
+
+        @validates('addresses', include_backrefs=False)
+        def validate_address(self, key, address):
+            assert '@' in address.email
+            return address
+
+Above, if we were to assign to ``Address.user`` as in ``some_address.user = some_user``,
+the ``validate_address()`` function would *not* be emitted, even though an append
+occurs to ``some_user.addresses`` - the event is caused by a backref.
+
 Note that the :func:`~.validates` decorator is a convenience function built on
 top of attribute events.   An application that requires more control over
 configuration of attribute change behavior can make use of this system,
index 375e7b1afba15ecef3ba413cb1ebf15c41eaaa65..9b91d06383c065c0d84ea1466a2eafc784a02455 100644 (file)
@@ -1103,11 +1103,10 @@ class Mapper(_InspectionAttr):
                     self._reconstructor = method
                     event.listen(manager, 'load', _event_on_load, raw=True)
                 elif hasattr(method, '__sa_validators__'):
-                    include_removes = getattr(method,
-                                            "__sa_include_removes__", False)
+                    validation_opts = method.__sa_validation_opts__
                     for name in method.__sa_validators__:
                         self.validators = self.validators.union(
-                            {name: (method, include_removes)}
+                            {name: (method, validation_opts)}
                         )
 
         manager.info[_INSTRUMENTOR] = self
@@ -2582,13 +2581,28 @@ def validates(*names, **kw):
      argument "is_remove" which will be a boolean.
 
      .. versionadded:: 0.7.7
+    :param include_backrefs: defaults to ``True``; if ``False``, the
+     validation function will not emit if the originator is an attribute
+     event related via a backref.  This can be used for bi-directional
+     :func:`.validates` usage where only one validator should emit per
+     attribute operation.
+
+     .. versionadded:: 0.9.0b2
+
+    .. seealso::
+
+      :ref:`simple_validators` - usage examples for :func:`.validates`
 
     """
     include_removes = kw.pop('include_removes', False)
+    include_backrefs = kw.pop('include_backrefs', True)
 
     def wrap(fn):
         fn.__sa_validators__ = names
-        fn.__sa_include_removes__ = include_removes
+        fn.__sa_validation_opts__ = {
+                  "include_removes": include_removes,
+                  "include_backrefs": include_backrefs
+                }
         return fn
     return wrap
 
index b04338d9c66d80dc16705f5e185843d794604b1f..8226a0e0f2fecc6d8ca41732723b4fd3fc43b2df 100644 (file)
@@ -44,10 +44,10 @@ def _register_attribute(strategy, mapper, useobject,
         listen_hooks.append(single_parent_validator)
 
     if prop.key in prop.parent.validators:
-        fn, include_removes = prop.parent.validators[prop.key]
+        fn, opts = prop.parent.validators[prop.key]
         listen_hooks.append(
             lambda desc, prop: orm_util._validator_events(desc,
-                                prop.key, fn, include_removes)
+                                prop.key, fn, **opts)
             )
 
     if useobject:
index 1b8f53c9d71c0cdd1ba2bb467c2445ae87af7178..b8667217579323154a4ed6a67443de00807ce564 100644 (file)
@@ -70,24 +70,43 @@ class CascadeOptions(frozenset):
         )
 
 
-def _validator_events(desc, key, validator, include_removes):
+def _validator_events(desc, key, validator, include_removes, include_backrefs):
     """Runs a validation method on an attribute value to be set or appended."""
 
+    if not include_backrefs:
+        def detect_is_backref(state, initiator):
+            impl = state.manager[key].impl
+            return initiator.impl is not impl
+
     if include_removes:
         def append(state, value, initiator):
-            return validator(state.obj(), key, value, False)
+            if include_backrefs or not detect_is_backref(state, initiator):
+                return validator(state.obj(), key, value, False)
+            else:
+                return value
 
         def set_(state, value, oldvalue, initiator):
-            return validator(state.obj(), key, value, False)
+            if include_backrefs or not detect_is_backref(state, initiator):
+                return validator(state.obj(), key, value, False)
+            else:
+                return value
 
         def remove(state, value, initiator):
-            validator(state.obj(), key, value, True)
+            if include_backrefs or not detect_is_backref(state, initiator):
+                validator(state.obj(), key, value, True)
+
     else:
         def append(state, value, initiator):
-            return validator(state.obj(), key, value)
+            if include_backrefs or not detect_is_backref(state, initiator):
+                return validator(state.obj(), key, value)
+            else:
+                return value
 
         def set_(state, value, oldvalue, initiator):
-            return validator(state.obj(), key, value)
+            if include_backrefs or not detect_is_backref(state, initiator):
+                return validator(state.obj(), key, value)
+            else:
+                return value
 
     event.listen(desc, 'append', append, raw=True, retval=True)
     event.listen(desc, 'set', set_, raw=True, retval=True)
index b1c9d3fb691f9a7d3a29df7386abf7a0893cef2e..a3222bc8fd039cef17eb2bea8c1ff0528b241d62 100644 (file)
@@ -2035,139 +2035,6 @@ class DeepOptionsTest(_fixtures.FixtureTest):
             x = u[0].orders[1].items[0].keywords[1]
         self.sql_count_(2, go)
 
-class ValidatorTest(_fixtures.FixtureTest):
-    def test_scalar(self):
-        users = self.tables.users
-        canary = []
-        class User(fixtures.ComparableEntity):
-            @validates('name')
-            def validate_name(self, key, name):
-                canary.append((key, name))
-                assert name != 'fred'
-                return name + ' modified'
-
-        mapper(User, users)
-        sess = create_session()
-        u1 = User(name='ed')
-        eq_(u1.name, 'ed modified')
-        assert_raises(AssertionError, setattr, u1, "name", "fred")
-        eq_(u1.name, 'ed modified')
-        eq_(canary, [('name', 'ed'), ('name', 'fred')])
-        sess.add(u1)
-        sess.flush()
-        sess.expunge_all()
-        eq_(sess.query(User).filter_by(name='ed modified').one(), User(name='ed'))
-
-    def test_collection(self):
-        users, addresses, Address = (self.tables.users,
-                                self.tables.addresses,
-                                self.classes.Address)
-
-        canary = []
-        class User(fixtures.ComparableEntity):
-            @validates('addresses')
-            def validate_address(self, key, ad):
-                canary.append((key, ad))
-                assert '@' in ad.email_address
-                return ad
-
-        mapper(User, users, properties={'addresses':relationship(Address)})
-        mapper(Address, addresses)
-        sess = create_session()
-        u1 = User(name='edward')
-        a0 = Address(email_address='noemail')
-        assert_raises(AssertionError, u1.addresses.append, a0)
-        a1 = Address(id=15, email_address='foo@bar.com')
-        u1.addresses.append(a1)
-        eq_(canary, [('addresses', a0), ('addresses', a1)])
-        sess.add(u1)
-        sess.flush()
-        sess.expunge_all()
-        eq_(
-            sess.query(User).filter_by(name='edward').one(),
-            User(name='edward', addresses=[Address(email_address='foo@bar.com')])
-        )
-
-    def test_validators_dict(self):
-        users, addresses, Address = (self.tables.users,
-                                     self.tables.addresses,
-                                     self.classes.Address)
-
-        class User(fixtures.ComparableEntity):
-
-            @validates('name')
-            def validate_name(self, key, name):
-                assert name != 'fred'
-                return name + ' modified'
-
-            @validates('addresses')
-            def validate_address(self, key, ad):
-                assert '@' in ad.email_address
-                return ad
-
-            def simple_function(self, key, value):
-                return key, value
-
-        u_m = mapper(User,
-                      users,
-                      properties={'addresses':relationship(Address)})
-        mapper(Address, addresses)
-
-        eq_(
-            dict((k, v[0].__name__) for k, v in list(u_m.validators.items())),
-            {'name':'validate_name',
-            'addresses':'validate_address'}
-        )
-
-    def test_validator_w_removes(self):
-        users, addresses, Address = (self.tables.users,
-                                     self.tables.addresses,
-                                     self.classes.Address)
-        canary = []
-        class User(fixtures.ComparableEntity):
-
-            @validates('name', include_removes=True)
-            def validate_name(self, key, item, remove):
-                canary.append((key, item, remove))
-                return item
-
-            @validates('addresses', include_removes=True)
-            def validate_address(self, key, item, remove):
-                canary.append((key, item, remove))
-                return item
-
-        mapper(User,
-                      users,
-                      properties={'addresses':relationship(Address)})
-        mapper(Address, addresses)
-
-        u1 = User()
-        u1.name = "ed"
-        u1.name = "mary"
-        del u1.name
-        a1, a2, a3 = Address(), Address(), Address()
-        u1.addresses.append(a1)
-        u1.addresses.remove(a1)
-        u1.addresses = [a1, a2]
-        u1.addresses = [a2, a3]
-
-        eq_(canary, [
-                ('name', 'ed', False),
-                ('name', 'mary', False),
-                ('name', 'mary', True),
-                # append a1
-                ('addresses', a1, False),
-                # remove a1
-                ('addresses', a1, True),
-                # set to [a1, a2] - this is two appends
-                ('addresses', a1, False), ('addresses', a2, False),
-                # set to [a2, a3] - this is a remove of a1,
-                # append of a3.  the appends are first.
-                ('addresses', a3, False),
-                ('addresses', a1, True),
-            ]
-        )
-
 class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_kwarg_accepted(self):
         users, Address = self.tables.users, self.classes.Address
diff --git a/test/orm/test_validators.py b/test/orm/test_validators.py
new file mode 100644 (file)
index 0000000..417554f
--- /dev/null
@@ -0,0 +1,281 @@
+from test.orm import _fixtures
+from sqlalchemy.testing import fixtures, assert_raises, eq_, ne_
+from sqlalchemy.orm import mapper, Session, validates, relationship
+from sqlalchemy.testing.mock import Mock, call
+
+
+class ValidatorTest(_fixtures.FixtureTest):
+    def test_scalar(self):
+        users = self.tables.users
+        canary = Mock()
+
+        class User(fixtures.ComparableEntity):
+            @validates('name')
+            def validate_name(self, key, name):
+                canary(key, name)
+                ne_(name, 'fred')
+                return name + ' modified'
+
+        mapper(User, users)
+        sess = Session()
+        u1 = User(name='ed')
+        eq_(u1.name, 'ed modified')
+        assert_raises(AssertionError, setattr, u1, "name", "fred")
+        eq_(u1.name, 'ed modified')
+        eq_(canary.mock_calls, [call('name', 'ed'), call('name', 'fred')])
+
+        sess.add(u1)
+        sess.commit()
+
+        eq_(
+            sess.query(User).filter_by(name='ed modified').one(),
+            User(name='ed')
+        )
+
+    def test_collection(self):
+        users, addresses, Address = (self.tables.users,
+                                self.tables.addresses,
+                                self.classes.Address)
+
+        canary = Mock()
+        class User(fixtures.ComparableEntity):
+            @validates('addresses')
+            def validate_address(self, key, ad):
+                canary(key, ad)
+                assert '@' in ad.email_address
+                return ad
+
+        mapper(User, users, properties={
+                'addresses': relationship(Address)}
+        )
+        mapper(Address, addresses)
+        sess = Session()
+        u1 = User(name='edward')
+        a0 = Address(email_address='noemail')
+        assert_raises(AssertionError, u1.addresses.append, a0)
+        a1 = Address(id=15, email_address='foo@bar.com')
+        u1.addresses.append(a1)
+        eq_(canary.mock_calls, [call('addresses', a0), call('addresses', a1)])
+        sess.add(u1)
+        sess.commit()
+
+        eq_(
+            sess.query(User).filter_by(name='edward').one(),
+            User(name='edward', addresses=[Address(email_address='foo@bar.com')])
+        )
+
+    def test_validators_dict(self):
+        users, addresses, Address = (self.tables.users,
+                                     self.tables.addresses,
+                                     self.classes.Address)
+
+        class User(fixtures.ComparableEntity):
+
+            @validates('name')
+            def validate_name(self, key, name):
+                ne_(name, 'fred')
+                return name + ' modified'
+
+            @validates('addresses')
+            def validate_address(self, key, ad):
+                assert '@' in ad.email_address
+                return ad
+
+            def simple_function(self, key, value):
+                return key, value
+
+        u_m = mapper(User, users, properties={
+                'addresses': relationship(Address)
+            }
+        )
+        mapper(Address, addresses)
+
+        eq_(
+            dict((k, v[0].__name__) for k, v in list(u_m.validators.items())),
+            {'name': 'validate_name',
+            'addresses': 'validate_address'}
+        )
+
+    def test_validator_w_removes(self):
+        users, addresses, Address = (self.tables.users,
+                                     self.tables.addresses,
+                                     self.classes.Address)
+        canary = Mock()
+        class User(fixtures.ComparableEntity):
+
+            @validates('name', include_removes=True)
+            def validate_name(self, key, item, remove):
+                canary(key, item, remove)
+                return item
+
+            @validates('addresses', include_removes=True)
+            def validate_address(self, key, item, remove):
+                canary(key, item, remove)
+                return item
+
+        mapper(User, users, properties={
+                'addresses': relationship(Address)
+            })
+        mapper(Address, addresses)
+
+        u1 = User()
+        u1.name = "ed"
+        u1.name = "mary"
+        del u1.name
+        a1, a2, a3 = Address(), Address(), Address()
+        u1.addresses.append(a1)
+        u1.addresses.remove(a1)
+        u1.addresses = [a1, a2]
+        u1.addresses = [a2, a3]
+
+        eq_(canary.mock_calls, [
+                call('name', 'ed', False),
+                call('name', 'mary', False),
+                call('name', 'mary', True),
+                # append a1
+                call('addresses', a1, False),
+                # remove a1
+                call('addresses', a1, True),
+                # set to [a1, a2] - this is two appends
+                call('addresses', a1, False), call('addresses', a2, False),
+                # set to [a2, a3] - this is a remove of a1,
+                # append of a3.  the appends are first.
+                call('addresses', a3, False),
+                call('addresses', a1, True),
+            ]
+        )
+
+    def test_validator_wo_backrefs_wo_removes(self):
+        self._test_validator_backrefs(False, False)
+
+    def test_validator_wo_backrefs_w_removes(self):
+        self._test_validator_backrefs(False, True)
+
+    def test_validator_w_backrefs_wo_removes(self):
+        self._test_validator_backrefs(True, False)
+
+    def test_validator_w_backrefs_w_removes(self):
+        self._test_validator_backrefs(True, True)
+
+    def _test_validator_backrefs(self, include_backrefs, include_removes):
+        users, addresses = (self.tables.users,
+                                     self.tables.addresses)
+        canary = Mock()
+        class User(fixtures.ComparableEntity):
+
+            if include_removes:
+                @validates('addresses', include_removes=True,
+                                        include_backrefs=include_backrefs)
+                def validate_address(self, key, item, remove):
+                    canary(key, item, remove)
+                    return item
+            else:
+                @validates('addresses', include_removes=False,
+                                        include_backrefs=include_backrefs)
+                def validate_address(self, key, item):
+                    canary(key, item)
+                    return item
+
+        class Address(fixtures.ComparableEntity):
+            if include_removes:
+                @validates('user', include_backrefs=include_backrefs,
+                                            include_removes=True)
+                def validate_user(self, key, item, remove):
+                    canary(key, item, remove)
+                    return item
+            else:
+                @validates('user', include_backrefs=include_backrefs)
+                def validate_user(self, key, item):
+                    canary(key, item)
+                    return item
+
+        mapper(User, users, properties={
+                'addresses': relationship(Address, backref="user")
+            })
+        mapper(Address, addresses)
+
+        u1 = User()
+        u2 = User()
+        a1, a2 = Address(), Address()
+
+        # 3 append/set, two removes
+        u1.addresses.append(a1)
+        u1.addresses.append(a2)
+        a2.user = u2
+        del a1.user
+        u2.addresses.remove(a2)
+
+        # copy, so that generation of the
+        # comparisons don't get caught
+        calls = list(canary.mock_calls)
+
+        if include_backrefs:
+            if include_removes:
+                eq_(calls,
+                    [
+                        # append #1
+                        call('addresses', Address(), False),
+
+                        # backref for append
+                        call('user', User(addresses=[]), False),
+
+                        # append #2
+                        call('addresses', Address(user=None), False),
+
+                        # backref for append
+                        call('user', User(addresses=[]), False),
+
+                        # assign a2.user = u2
+                        call('user', User(addresses=[]), False),
+
+                        # backref for u1.addresses.remove(a2)
+                        call('addresses', Address(user=None), True),
+
+                        # backref for u2.addresses.append(a2)
+                        call('addresses', Address(user=None), False),
+
+                        # del a1.user
+                        call('user', User(addresses=[]), True),
+
+                        # backref for u1.addresses.remove(a1)
+                        call('addresses', Address(), True),
+
+                        # u2.addresses.remove(a2)
+                        call('addresses', Address(user=None), True),
+
+                        # backref for a2.user = None
+                        call('user', None, False)
+                    ]
+                )
+            else:
+                eq_(calls,
+                    [
+                        call('addresses', Address()),
+                        call('user', User(addresses=[])),
+                        call('addresses', Address(user=None)),
+                        call('user', User(addresses=[])),
+                        call('user', User(addresses=[])),
+                        call('addresses', Address(user=None)),
+                        call('user', None)
+                    ]
+                )
+        else:
+            if include_removes:
+                eq_(calls,
+                    [
+                         call('addresses', Address(), False),
+                         call('addresses', Address(user=None), False),
+                         call('user', User(addresses=[]), False),
+                         call('user', User(addresses=[]), True),
+                         call('addresses', Address(user=None), True)
+                    ]
+
+                )
+            else:
+                eq_(calls,
+                    [
+                        call('addresses', Address()),
+                        call('addresses', Address(user=None)),
+                        call('user', User(addresses=[]))
+                    ]
+                )