]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- column_property(), composite_property(), and relation() now
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Sep 2008 19:51:48 +0000 (19:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Sep 2008 19:51:48 +0000 (19:51 +0000)
accept a single or list of AttributeExtensions using the
"extension" keyword argument.
- Added a Validator AttributeExtension, as well as a
@validates decorator which is used in a similar fashion
as @reconstructor, and marks a method as validating
one or more mapped attributes.
- removed validate_attributes example, the new methodology replaces it

CHANGES
doc/build/content/mappers.txt
examples/custom_attributes/listen_for_events.py
examples/custom_attributes/validate_attributes.py [deleted file]
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/util.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 5d91ac46c647fe683e7e8d27204ef156f2c14be1..8ec9ecc33db8c55746ce66f4249d724bcdf1ec78 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -33,16 +33,22 @@ CHANGES
       clause will appear in the WHERE clause of the query as well
       since this discrimination has multiple trigger points.
 
-    - AttributeListener has been refined such that the event
+    - AttributeExtension has been refined such that the event
       is fired before the mutation actually occurs.  Addtionally,
       the append() and set() methods must now return the given value,
       which is used as the value to be used in the mutation operation.
       This allows creation of validating AttributeListeners which
       raise before the action actually occurs, and which can change
       the given value into something else before its used.
-      A new example "validate_attributes.py" shows one such recipe
-      for doing this.   AttributeListener helper functions are
-      also on the way.
+
+    - column_property(), composite_property(), and relation() now 
+      accept a single or list of AttributeExtensions using the 
+      "extension" keyword argument.
+      
+    - Added a Validator AttributeExtension, as well as a 
+      @validates decorator which is used in a similar fashion
+      as @reconstructor, and marks a method as validating
+      one or more mapped attributes.
       
     - class.someprop.in_() raises NotImplementedError pending the
       implementation of "in_" for relation [ticket:1140]
index 3e9156335b35f5396585fb2b18652ce8481fd016..bbf9ffaf550de1e7e07f58d16f43121a3ef63501 100644 (file)
@@ -139,18 +139,54 @@ Correlated subqueries may be used as well:
             )
     })
 
-#### Overriding Attribute Behavior with Synonyms {@name=overriding}
+#### Changing Attribute Behavior {@name=attributes}
 
-A common request is the ability to create custom class properties that override the behavior of setting/getting an attribute.  As of 0.4.2, the `synonym()` construct provides an easy way to do this in conjunction with a normal Python `property` constructs.  Below, we re-map the `email` column of our mapped table to a custom attribute setter/getter, mapping the actual column to the property named `_email`:
+##### Simple Validators {@name=validators}
+
+A quick way to add a "validation" routine to an attribute is to use the `@validates` decorator.  This is a shortcut for using the [docstrings_sqlalchemy.orm_Validator](rel:docstrings_sqlalchemy.orm_Validator) attribute extension with individual column or relation based attributes.   An attribute validator can raise an exception, halting the process of mutating the attribute's value, or can change the given value into something different.   Validators, like all attribute extensions, are only called by normal userland code; they are not issued when the ORM is populating the object.
 
     {python}
-    class MyAddress(object):
+    addresses_table = Table('addresses', metadata, 
+        Column('id', Integer, primary_key=True),
+        Column('email', String)
+    )
+    
+    class EmailAddress(object):
+        @validates('email')
+        def validate_email(self, key, address):
+            assert '@' in address
+            return address
+            
+    mapper(EmailAddress, addresses_table)
+        
+Validators also receive collection events, when items are added to a collection:
+
+    {python}
+    class User(object):
+        @validates('addresses')
+        def validate_address(self, key, address):
+            assert '@' in address.email
+            return address
+    
+##### Using Descriptors {@name=overriding}
+
+A more comprehensive way to produce modified behavior for an attribute is to use descriptors.   These are commonly used in Python using the `property()` function.   The standard SQLAlchemy technique for descriptors is to create a plain descriptor, and to have it read/write from a mapped attribute with a different name.  To have the descriptor named the same as a column, map the column under a different name, i.e.:
+
+    {python}
+    class EmailAddress(object):
        def _set_email(self, email):
           self._email = email
        def _get_email(self):
           return self._email
        email = property(_get_email, _set_email)
 
+    mapper(MyAddress, addresses_table, properties={
+        '_email': addresses_table.c.email
+    })
+    
+However, the approach above is not complete.  While our `EmailAddress` object will shuttle the value through the `email` descriptor and into the `_email` mapped attribute, the class level `EmailAddress.email` attribute does not have the usual expression semantics usable with `Query`.  To provide these, we instead use the `synonym()` function as follows:
+
+    {python}
     mapper(MyAddress, addresses_table, properties={
         'email': synonym('_email', map_column=True)
     })
index e980e61edcb7c08ed46b2660df39a477bac7ead1..c028e0fb48220d530b3d4be12fc6f913f4fece73 100644 (file)
@@ -1,5 +1,6 @@
 """
-Illustrates how to use AttributeExtension to listen for change events.
+Illustrates how to use AttributeExtension to listen for change events 
+across the board.
 
 """
 
diff --git a/examples/custom_attributes/validate_attributes.py b/examples/custom_attributes/validate_attributes.py
deleted file mode 100644 (file)
index 63b2529..0000000
+++ /dev/null
@@ -1,117 +0,0 @@
-"""
-Illustrates how to use AttributeExtension to create attribute validators.
-
-"""
-
-from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager
-
-class InstallValidators(InstrumentationManager):
-    """Searches a class for methods with a '_validates' attribute and assembles Validators."""
-    
-    def __init__(self, cls):
-        self.validators = {}
-        for k in dir(cls):
-            item = getattr(cls, k)
-            if hasattr(item, '_validates'):
-                self.validators[item._validates] = item
-                
-    def instrument_attribute(self, class_, key, inst):
-        """Add an event listener to an InstrumentedAttribute."""
-        
-        if key in self.validators:
-            inst.impl.extensions.insert(0, Validator(key, self.validators[key]))
-        return super(InstallValidators, self).instrument_attribute(class_, key, inst)
-        
-class Validator(AttributeExtension):
-    """Validates an attribute, given the key and a validation function."""
-    
-    def __init__(self, key, validator):
-        self.key = key
-        self.validator = validator
-    
-    def append(self, state, value, initiator):
-        return self.validator(state.obj(), value)
-
-    def set(self, state, value, oldvalue, initiator):
-        return self.validator(state.obj(), value)
-
-def validates(key):
-    """Mark a method as validating a named attribute."""
-    
-    def wrap(fn):
-        fn._validates = key
-        return fn
-    return wrap
-
-if __name__ == '__main__':
-
-    from sqlalchemy import *
-    from sqlalchemy.orm import *
-    from sqlalchemy.ext.declarative import declarative_base
-    import datetime
-    
-    Base = declarative_base(engine=create_engine('sqlite://', echo=True))
-    Base.__sa_instrumentation_manager__ = InstallValidators
-
-    class MyMappedClass(Base):
-        __tablename__ = "mytable"
-    
-        id = Column(Integer, primary_key=True)
-        date = Column(Date)
-        related_id = Column(Integer, ForeignKey("related.id"))
-        related = relation("Related", backref="mapped")
-
-        @validates('date')
-        def check_date(self, value):
-            if isinstance(value, str):
-                m, d, y = [int(x) for x in value.split('/')]
-                return datetime.date(y, m, d)
-            else:
-                assert isinstance(value, datetime.date)
-                return value
-        
-        @validates('related')
-        def check_related(self, value):
-            assert value.data == 'r1'
-            return value
-            
-        def __str__(self):
-            return "MyMappedClass(date=%r)" % self.date
-            
-    class Related(Base):
-        __tablename__ = "related"
-
-        id = Column(Integer, primary_key=True)
-        data = Column(String(50))
-
-        def __str__(self):
-            return "Related(data=%r)" % self.data
-    
-    Base.metadata.create_all()
-    session = sessionmaker()()
-    
-    r1 = Related(data='r1')
-    r2 = Related(data='r2')
-    m1 = MyMappedClass(date='5/2/2005', related=r1)
-    m2 = MyMappedClass(date=datetime.date(2008, 10, 15))
-    r1.mapped.append(m2)
-
-    try:
-        m1.date = "this is not a date"
-    except:
-        pass
-    assert m1.date == datetime.date(2005, 5, 2)
-    
-    try:
-        m2.related = r2
-    except:
-        pass
-    assert m2.related is r1
-    
-    session.add(m1)
-    session.commit()
-    assert session.query(MyMappedClass.date).order_by(MyMappedClass.date).all() == [
-        (datetime.date(2005, 5, 2),),
-        (datetime.date(2008, 10, 15),)
-    ]
-    
\ No newline at end of file
index da9134b112ee7520b6153ef5bbe6c7df10d55759..4496c21e4f5819607cb0acbd1b2cce8e6c471d02 100644 (file)
@@ -29,6 +29,7 @@ from sqlalchemy.orm.interfaces import (
      )
 from sqlalchemy.orm.util import (
      AliasedClass as aliased,
+     Validator,
      join,
      object_mapper,
      outerjoin,
@@ -44,7 +45,7 @@ from sqlalchemy.orm.properties import (
      SynonymProperty,
      )
 from sqlalchemy.orm import mapper as mapperlib
-from sqlalchemy.orm.mapper import reconstructor
+from sqlalchemy.orm.mapper import reconstructor, validates
 from sqlalchemy.orm import strategies
 from sqlalchemy.orm.query import AliasOption, Query
 from sqlalchemy.sql import util as sql_util
@@ -59,6 +60,7 @@ __all__ = (
     'EXT_STOP',
     'InstrumentationManager',
     'MapperExtension',
+    'Validator',
     'PropComparator',
     'Query',
     'aliased',
@@ -91,6 +93,7 @@ __all__ = (
     'synonym',
     'undefer',
     'undefer_group',
+    'validates'
     )
 
 
@@ -206,6 +209,14 @@ def relation(argument, secondary=None, **kwargs):
           a class or function that returns a new list-holding object. will be
           used in place of a plain list for storing elements.
 
+        extension
+          an [sqlalchemy.orm.interfaces#AttributeExtension] instance, 
+          or list of extensions, which will be prepended to the list of 
+          attribute listeners for the resulting descriptor placed on the class.
+          These listeners will receive append and set events before the 
+          operation proceeds, and may be used to halt (via exception throw)
+          or change the value used in the operation.
+          
         foreign_keys
           a list of columns which are to be used as "foreign key" columns.
           this parameter should be used in conjunction with explicit
@@ -396,6 +407,14 @@ def column_property(*args, **kwargs):
           attribute is first accessed on an instance.  See also
           [sqlalchemy.orm#deferred()].
 
+      extension
+        an [sqlalchemy.orm.interfaces#AttributeExtension] instance, 
+        or list of extensions, which will be prepended to the list of 
+        attribute listeners for the resulting descriptor placed on the class.
+        These listeners will receive append and set events before the 
+        operation proceeds, and may be used to halt (via exception throw)
+        or change the value used in the operation.
+
     """
 
     return ColumnProperty(*args, **kwargs)
@@ -461,6 +480,14 @@ def composite(class_, *cols, **kwargs):
       An optional instance of [sqlalchemy.orm#PropComparator] which provides
       SQL expression generation functions for this composite type.
 
+    extension
+      an [sqlalchemy.orm.interfaces#AttributeExtension] instance, 
+      or list of extensions, which will be prepended to the list of 
+      attribute listeners for the resulting descriptor placed on the class.
+      These listeners will receive append and set events before the 
+      operation proceeds, and may be used to halt (via exception throw)
+      or change the value used in the operation.
+
     """
     return CompositeProperty(class_, *cols, **kwargs)
 
index 99593e88940bed9862f8b32e52041dc27a3d4d3a..cfea61e26c4cf452747a99008ba6dd3b51acf40b 100644 (file)
@@ -133,6 +133,7 @@ class Mapper(object):
         self.column_prefix = column_prefix
         self.polymorphic_on = polymorphic_on
         self._dependency_processors = []
+        self._validators = {}
         self._clause_adapter = None
         self._requires_row_aliasing = False
         self.__inherits_equated_pairs = None
@@ -868,11 +869,13 @@ class Mapper(object):
         event_registry.add_listener('on_init', _event_on_init)
         event_registry.add_listener('on_init_failure', _event_on_init_failure)
         for key, method in util.iterate_attributes(self.class_):
-            if (isinstance(method, types.FunctionType) and
-                hasattr(method, '__sa_reconstructor__')):
-                event_registry.add_listener('on_load', method)
-                break
-
+            if isinstance(method, types.FunctionType):
+                if hasattr(method, '__sa_reconstructor__'):
+                    event_registry.add_listener('on_load', method)
+                elif hasattr(method, '__sa_validators__'):
+                    for name in method.__sa_validators__:
+                        self._validators[name] = method
+                        
         if 'reconstruct_instance' in self.extension.methods:
             def reconstruct(instance):
                 self.extension.reconstruct_instance(self, instance)
@@ -1652,7 +1655,22 @@ def reconstructor(fn):
     fn.__sa_reconstructor__ = True
     return fn
 
-
+def validates(*names):
+    """Decorate a method as a 'validator' for one or more named properties.
+    
+    Designates a method as a validator, a method which receives the 
+    name of the attribute as well as a value to be assigned, or in the
+    case of a collection to be added to the collection.  The function 
+    can then raise validation exceptions to halt the process from continuing,
+    or can modify or replace the value before proceeding.   The function
+    should otherwise return the given value.
+    
+    """
+    def wrap(fn):
+        fn.__sa_validators__ = names
+        return fn
+    return wrap
+    
 def _event_on_init(state, instance, args, kwargs):
     """Trigger mapper compilation and run init_instance hooks."""
 
index bbec299673b8fc8a871c680a754075c1e1184185..5266a682b65b5f0e3bcd37b7bba70c4501bfef26 100644 (file)
@@ -42,6 +42,7 @@ class ColumnProperty(StrategizedProperty):
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
         self.comparator_factory = kwargs.pop('comparator_factory', ColumnProperty.ColumnComparator)
+        self.extension = kwargs.pop('extension', None)
         util.set_creation_order(self)
         if self.deferred:
             self.strategy_class = strategies.DeferredColumnLoader
@@ -100,7 +101,7 @@ log.class_logger(ColumnProperty)
 
 class CompositeProperty(ColumnProperty):
     """subclasses ColumnProperty to provide composite type support."""
-
+    
     def __init__(self, class_, *columns, **kwargs):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
         self._col_position_map = dict((c, i) for i, c in enumerate(columns))
@@ -161,6 +162,9 @@ class CompositeProperty(ColumnProperty):
         return str(self.parent.class_.__name__) + "." + self.key
 
 class SynonymProperty(MapperProperty):
+
+    extension = None
+
     def __init__(self, name, map_column=None, descriptor=None, comparator_factory=None):
         self.name = name
         self.map_column = map_column
@@ -210,6 +214,8 @@ log.class_logger(SynonymProperty)
 class ComparableProperty(MapperProperty):
     """Instruments a Python property for use in query expressions."""
 
+    extension = None
+    
     def __init__(self, comparator_factory, descriptor=None):
         self.descriptor = descriptor
         self.comparator_factory = comparator_factory
@@ -244,7 +250,7 @@ class PropertyLoader(StrategizedProperty):
         backref=None,
         _is_backref=False,
         post_update=False,
-        cascade=False,
+        cascade=False, extension=None,
         viewonly=False, lazy=True,
         collection_class=None, passive_deletes=False,
         passive_updates=True, remote_side=None,
@@ -269,6 +275,7 @@ class PropertyLoader(StrategizedProperty):
         self.comparator = PropertyLoader.Comparator(self, None)
         self.join_depth = join_depth
         self.local_remote_pairs = _local_remote_pairs
+        self.extension = extension
         self.__join_cache = {}
         self.comparator_factory = PropertyLoader.Comparator
         util.set_creation_order(self)
index f254052ec921a07c4403db3e82f8e7306e5a2253..c1d93153e861bd8691466dc8fcd0a78b8db4b5e1 100644 (file)
@@ -23,6 +23,10 @@ class DefaultColumnLoader(LoaderStrategy):
     def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None, active_history=False):
         self.logger.info("%s register managed attribute" % self)
 
+        attribute_ext = util.to_list(self.parent_property.extension) or []
+        if self.key in self.parent._validators:
+            attribute_ext.append(mapperutil.Validator(self.key, self.parent._validators[self.key]))
+
         for mapper in self.parent.polymorphic_iterator():
             if (mapper is self.parent or not mapper.concrete) and mapper.has_property(self.key):
                 sessionlib.register_attribute(
@@ -36,6 +40,7 @@ class DefaultColumnLoader(LoaderStrategy):
                     comparator=comparator_factory(self.parent_property, mapper), 
                     parententity=mapper,
                     callable_=callable_,
+                    extension=attribute_ext,
                     proxy_property=proxy_property,
                     active_history=active_history
                     )
@@ -303,11 +308,14 @@ class AbstractRelationLoader(LoaderStrategy):
     def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
         self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
         
+        attribute_ext = util.to_list(self.parent_property.extension) or []
+        
         if self.parent_property.backref:
-            attribute_ext = self.parent_property.backref.extension
-        else:
-            attribute_ext = None
+            attribute_ext.append(self.parent_property.backref.extension)
         
+        if self.key in self.parent._validators:
+            attribute_ext.append(mapperutil.Validator(self.key, self.parent._validators[self.key]))
+            
         sessionlib.register_attribute(
             class_, 
             self.key, 
index 67a886306cdf48598bc97606f465432abe7ca6eb..3792e99be230f587713966d693cd7e70ba428d03 100644 (file)
@@ -81,7 +81,8 @@ def register_attribute(class_, key, *args, **kwargs):
         # for object-holding attributes, instrument UOWEventHandler
         # to process per-attribute cascades
         extension = util.to_list(kwargs.pop('extension', None) or [])
-        extension.insert(0, UOWEventHandler(key))
+        extension.append(UOWEventHandler(key))
+        
         kwargs['extension'] = extension
     return attributes.register_attribute(class_, key, *args, **kwargs)
 
index 7e244223b3a685bf7edd6d6cf8daf3e645527d89..453b9f510d7fca13bf8cad216ebc890dce58506d 100644 (file)
@@ -9,7 +9,7 @@ import new
 import sqlalchemy.exceptions as sa_exc
 from sqlalchemy import sql, util
 from sqlalchemy.sql import expression, util as sql_util, operators
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty, AttributeExtension
 from sqlalchemy.orm import attributes, exc
 
 
@@ -46,6 +46,32 @@ class CascadeOptions(object):
                          'delete_orphan', 'refresh-expire']
              if getattr(self, x, False) is True]))
 
+
+class Validator(AttributeExtension):
+    """Runs a validation method on an attribute value to be set or appended."""
+    
+    def __init__(self, key, validator):
+        """Construct a new Validator.
+        
+            key - name of the attribute to be validated;
+            will be passed as the second argument to 
+            the validation method (the first is the object instance itself).
+            
+            validator - an function or instance method which accepts
+            three arguments; an instance (usually just 'self' for a method),
+            the key name of the attribute, and the value.  The function should
+            return the same value given, unless it wishes to modify it.
+            
+        """
+        self.key = key
+        self.validator = validator
+    
+    def append(self, state, value, initiator):
+        return self.validator(state.obj(), self.key, value)
+
+    def set(self, state, value, oldvalue, initiator):
+        return self.validator(state.obj(), self.key, value)
+    
 def polymorphic_union(table_map, typecolname, aliasname='p_union'):
     """Create a ``UNION`` statement used by a polymorphic mapper.
 
index 37cda6eec0fa8c70b72630fb6d244a1c29b107e2..7b293e56472610dcbeb63d7f4516a552a38319ea 100644 (file)
@@ -3,7 +3,7 @@
 import testenv; testenv.configure_for_tests()
 from testlib import sa, testing
 from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, reconstructor
+from testlib.sa.orm import mapper, relation, backref, create_session, class_mapper, reconstructor, validates
 from testlib.sa.orm import defer, deferred, synonym, attributes
 from testlib.testing import eq_
 import pickleable
@@ -1086,7 +1086,49 @@ class DeepOptionsTest(_fixtures.FixtureTest):
             x = u[0].orders[1].items[0].keywords[1]
         self.sql_count_(2, go)
 
+class ValidatorTest(_fixtures.FixtureTest):
+    @testing.resolve_artifact_names
+    def test_scalar(self):
+        class User(_base.ComparableEntity):
+            @validates('name')
+            def validate_name(self, key, name):
+                assert name != 'fred'
+                return name + ' modified'
+                
+        mapper(User, users)
+        sess = create_session()
+        u1 = User(name='ed')
+        eq_(u1.name, 'ed modified')
+        self.assertRaises(AssertionError, setattr, u1, "name", "fred")
+        eq_(u1.name, 'ed modified')
+        sess.add(u1)
+        sess.flush()
+        sess.clear()
+        eq_(sess.query(User).filter_by(name='ed modified').one(), User(name='ed'))
+        
 
+    @testing.resolve_artifact_names
+    def test_collection(self):
+        class User(_base.ComparableEntity):
+            @validates('addresses')
+            def validate_address(self, key, ad):
+                assert '@' in ad.email_address
+                return ad
+                
+        mapper(User, users, properties={'addresses':relation(Address)})
+        mapper(Address, addresses)
+        sess = create_session()
+        u1 = User(name='edward')
+        self.assertRaises(AssertionError, u1.addresses.append, Address(email_address='noemail'))
+        u1.addresses.append(Address(id=15, email_address='foo@bar.com'))
+        sess.add(u1)
+        sess.flush()
+        sess.clear()
+        eq_(
+            sess.query(User).filter_by(name='edward').one(), 
+            User(name='edward', addresses=[Address(email_address='foo@bar.com')])
+        )
+        
 class DeferredTest(_fixtures.FixtureTest):
 
     @testing.resolve_artifact_names