]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] Added new flag to @validates
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Apr 2012 17:03:52 +0000 (13:03 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 11 Apr 2012 17:03:52 +0000 (13:03 -0400)
include_removes.  When True, collection
remove and attribute del events
will also be sent to the validation function,
which accepts an additional argument
"is_remove" when this flag is used.

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
test/orm/test_mapper.py

diff --git a/CHANGES b/CHANGES
index f4d46f67329a0af20174215341ce468c865c42cb..161668e82c717da8e2e52f34d0a9edba9ed4acd0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -12,6 +12,13 @@ CHANGES
     directives in statements.  Courtesy
     Diana Clarke [ticket:2443]
 
+  - [feature] Added new flag to @validates
+    include_removes.  When True, collection
+    remove and attribute del events
+    will also be sent to the validation function,
+    which accepts an additional argument
+    "is_remove" when this flag is used.
+
   - [bug] Fixed bug whereby polymorphic_on
     column that's not otherwise mapped on the 
     class would be incorrectly included
index e96b7549a9da892f99d658133f3bd1c3b87e3a66..afabac05a75708fad7804576ff4033e1f3b84c50 100644 (file)
@@ -678,9 +678,10 @@ class Mapper(object):
                     self._reconstructor = method
                     event.listen(manager, 'load', _event_on_load, raw=True)
                 elif hasattr(method, '__sa_validators__'):
+                    include_removes = getattr(method, "__sa_include_removes__", False)
                     for name in method.__sa_validators__:
                         self.validators = self.validators.union(
-                            {name : method}
+                            {name : (method, include_removes)}
                         )
 
         manager.info[_INSTRUMENTOR] = self
@@ -2291,7 +2292,7 @@ def reconstructor(fn):
     fn.__sa_reconstructor__ = True
     return fn
 
-def validates(*names):
+def validates(*names, **kw):
     """Decorate a method as a 'validator' for one or more named properties.
 
     Designates a method as a validator, a method which receives the
@@ -2307,9 +2308,16 @@ def validates(*names):
     an assertion to avoid recursion overflows.  This is a reentrant
     condition which is not supported.
 
+    :param \*names: list of attribute names to be validated.
+    :param include_removes: if True, "remove" events will be 
+     sent as well - the validation function must accept an additional
+     argument "is_remove" which will be a boolean.  New in 0.7.7.
+
     """
+    include_removes = kw.pop('include_removes', False)
     def wrap(fn):
         fn.__sa_validators__ = names
+        fn.__sa_include_removes__ = include_removes
         return fn
     return wrap
 
index 5f4b182d08e2f8b19fb53b32361d6b475cd28552..37980e11140e1f88de0c356ac63089085ae95ade 100644 (file)
@@ -45,11 +45,11 @@ def _register_attribute(strategy, mapper, useobject,
         listen_hooks.append(single_parent_validator)
 
     if prop.key in prop.parent.validators:
+        fn, include_removes = prop.parent.validators[prop.key]
         listen_hooks.append(
             lambda desc, prop: mapperutil._validator_events(desc, 
-                                prop.key, 
-                                prop.parent.validators[prop.key])
-        )
+                                prop.key, fn, include_removes)
+            )
 
     if useobject:
         listen_hooks.append(unitofwork.track_cascade_events)
index 0c5f203a72c84283638606699209e8150ea08e38..197c0c4c15925982845415966da5f8099c1278e9 100644 (file)
@@ -68,24 +68,36 @@ class CascadeOptions(frozenset):
             ",".join([x for x in sorted(self)])
         )
 
-def _validator_events(desc, key, validator):
+def _validator_events(desc, key, validator, include_removes):
     """Runs a validation method on an attribute value to be set or appended."""
 
-    def append(state, value, initiator):
-        return validator(state.obj(), key, value)
+    if include_removes:
+        def append(state, value, initiator):
+            return validator(state.obj(), key, value, False)
 
-    def set_(state, value, oldvalue, initiator):
-        return validator(state.obj(), key, value)
+        def set_(state, value, oldvalue, initiator):
+            return validator(state.obj(), key, value, False)
+
+        def remove(state, value, initiator):
+            validator(state.obj(), key, value, True)
+    else:
+        def append(state, value, initiator):
+            return validator(state.obj(), key, value)
+
+        def set_(state, value, oldvalue, initiator):
+            return validator(state.obj(), key, value)
 
     event.listen(desc, 'append', append, raw=True, retval=True)
     event.listen(desc, 'set', set_, raw=True, retval=True)
+    if include_removes:
+        event.listen(desc, "remove", remove, raw=True, retval=True)
 
 def polymorphic_union(table_map, typecolname, aliasname='p_union', cast_nulls=True):
     """Create a ``UNION`` statement used by a polymorphic mapper.
 
     See  :ref:`concrete_inheritance` for an example of how
     this is used.
-    
+
     :param table_map: mapping of polymorphic identities to 
      :class:`.Table` objects.
     :param typecolname: string name of a "discriminator" column, which will be 
@@ -236,7 +248,7 @@ class AliasedClass(object):
         session.query(User, user_alias).\\
                         join((user_alias, User.id > user_alias.id)).\\
                         filter(User.name==user_alias.name)
-    
+
     The resulting object is an instance of :class:`.AliasedClass`, however
     it implements a ``__getattribute__()`` scheme which will proxy attribute
     access to that of the ORM class being aliased.  All classmethods
@@ -244,7 +256,7 @@ class AliasedClass(object):
     hybrids created with the :ref:`hybrids_toplevel` extension,
     which will receive the :class:`.AliasedClass` as the "class" argument
     when classmethods are called.
-    
+
     :param cls: ORM mapped entity which will be "wrapped" around an alias.
     :param alias: a selectable, such as an :func:`.alias` or :func:`.select`
      construct, which will be rendered in place of the mapped table of the
@@ -259,28 +271,28 @@ class AliasedClass(object):
      otherwise have a column that corresponds to one on the entity.  The 
      use case for this is when associating an entity with some derived
      selectable such as one that uses aggregate functions::
-     
+
         class UnitPrice(Base):
             __tablename__ = 'unit_price'
             ...
             unit_id = Column(Integer)
             price = Column(Numeric)
-        
+
         aggregated_unit_price = Session.query(
                                     func.sum(UnitPrice.price).label('price')
                                 ).group_by(UnitPrice.unit_id).subquery()
-                                
+
         aggregated_unit_price = aliased(UnitPrice, alias=aggregated_unit_price, adapt_on_names=True)
-    
+
      Above, functions on ``aggregated_unit_price`` which
      refer to ``.price`` will return the
      ``fund.sum(UnitPrice.price).label('price')`` column,
      as it is matched on the name "price".  Ordinarily, the "price" function wouldn't
      have any "column correspondence" to the actual ``UnitPrice.price`` column
      as it is not a proxy of the original.
-     
+
      ``adapt_on_names`` is new in 0.7.3.
-        
+
     """
     def __init__(self, cls, alias=None, name=None, adapt_on_names=False):
         self.__mapper = _class_to_mapper(cls)
@@ -447,7 +459,7 @@ class _ORMJoin(expression.Join):
 
 def join(left, right, onclause=None, isouter=False, join_to_left=True):
     """Produce an inner join between left and right clauses.
-    
+
     :func:`.orm.join` is an extension to the core join interface
     provided by :func:`.sql.expression.join()`, where the
     left and right selectables may be not only core selectable
@@ -460,7 +472,7 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True):
     in whatever form it is passed, to the selectable
     passed as the left side.  If False, the onclause
     is used as is.
-    
+
     :func:`.orm.join` is not commonly needed in modern usage,
     as its functionality is encapsulated within that of the
     :meth:`.Query.join` method, which features a
@@ -468,22 +480,22 @@ def join(left, right, onclause=None, isouter=False, join_to_left=True):
     by itself.  Explicit usage of :func:`.orm.join` 
     with :class:`.Query` involves usage of the 
     :meth:`.Query.select_from` method, as in::
-    
+
         from sqlalchemy.orm import join
         session.query(User).\\
             select_from(join(User, Address, User.addresses)).\\
             filter(Address.email_address=='foo@bar.com')
-    
+
     In modern SQLAlchemy the above join can be written more 
     succinctly as::
-    
+
         session.query(User).\\
                 join(User.addresses).\\
                 filter(Address.email_address=='foo@bar.com')
 
     See :meth:`.Query.join` for information on modern usage
     of ORM level joins.
-    
+
     """
     return _ORMJoin(left, right, onclause, isouter, join_to_left)
 
index 1c5f29b71677bac3aeaf6f6bf521e1d57020024e..79ae7ff5906406bae52e5ba2908eae9f33830444 100644 (file)
@@ -1950,10 +1950,11 @@ class DeepOptionsTest(_fixtures.FixtureTest):
 class ValidatorTest(_fixtures.FixtureTest):
     def test_scalar(self):
         users = self.tables.users
-
+        canary = []
         class User(fixtures.ComparableEntity):
             @validates('name')
             def validate_name(self, key, name):
+                canary.append((key, name))
                 assert name != 'fred'
                 return name + ' modified'
 
@@ -1963,6 +1964,7 @@ class ValidatorTest(_fixtures.FixtureTest):
         eq_(u1.name, 'ed modified')
         assert_raises(AssertionError, setattr, u1, "name", "fred")
         eq_(u1.name, 'ed modified')
+        eq_(canary, [('name', 'ed'), ('name', 'fred')])
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
@@ -1973,9 +1975,11 @@ class ValidatorTest(_fixtures.FixtureTest):
                                 self.tables.addresses,
                                 self.classes.Address)
 
+        canary = []
         class User(fixtures.ComparableEntity):
             @validates('addresses')
             def validate_address(self, key, ad):
+                canary.append((key, ad))
                 assert '@' in ad.email_address
                 return ad
 
@@ -1983,8 +1987,11 @@ class ValidatorTest(_fixtures.FixtureTest):
         mapper(Address, addresses)
         sess = create_session()
         u1 = User(name='edward')
-        assert_raises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
-        u1.addresses.append(Address(id=15, email_address='foo@bar.com'))
+        a0 = Address(email_address='noemail')
+        assert_raises(AssertionError, u1.addresses.append, a0)
+        a1 = Address(id=15, email_address='foo@bar.com')
+        u1.addresses.append(a1)
+        eq_(canary, [('addresses', a0), ('addresses', a1)])
         sess.add(u1)
         sess.flush()
         sess.expunge_all()
@@ -2019,11 +2026,60 @@ class ValidatorTest(_fixtures.FixtureTest):
         mapper(Address, addresses)
 
         eq_(
-            dict((k, v.__name__) for k, v in u_m.validators.items()),
+            dict((k, v[0].__name__) for k, v in u_m.validators.items()),
             {'name':'validate_name', 
             'addresses':'validate_address'}
         )
 
+    def test_validator_w_removes(self):
+        users, addresses, Address = (self.tables.users,
+                                     self.tables.addresses,
+                                     self.classes.Address)
+        canary = []
+        class User(fixtures.ComparableEntity):
+
+            @validates('name', include_removes=True)
+            def validate_name(self, key, item, remove):
+                canary.append((key, item, remove))
+                return item
+
+            @validates('addresses', include_removes=True)
+            def validate_address(self, key, item, remove):
+                canary.append((key, item, remove))
+                return item
+
+        mapper(User,
+                      users,
+                      properties={'addresses':relationship(Address)})
+        mapper(Address, addresses)
+
+        u1 = User()
+        u1.name = "ed"
+        u1.name = "mary"
+        del u1.name
+        a1, a2, a3 = Address(), Address(), Address()
+        u1.addresses.append(a1)
+        u1.addresses.remove(a1)
+        u1.addresses = [a1, a2]
+        u1.addresses = [a2, a3]
+
+        eq_(canary, [
+                ('name', 'ed', False), 
+                ('name', 'mary', False), 
+                ('name', 'mary', True), 
+                # append a1
+                ('addresses', a1, False), 
+                # remove a1
+                ('addresses', a1, True), 
+                # set to [a1, a2] - this is two appends
+                ('addresses', a1, False), ('addresses', a2, False),
+                # set to [a2, a3] - this is a remove of a1,
+                # append of a3.  the appends are first.
+                ('addresses', a3, False),
+                ('addresses', a1, True), 
+            ] 
+        )
+
 class ComparatorFactoryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     def test_kwarg_accepted(self):
         users, Address = self.tables.users, self.classes.Address