]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
attributes overhaul #2 - attribute manager now tracks class-level initializers strict...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 21:04:16 +0000 (21:04 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 21:04:16 +0000 (21:04 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapping/unitofwork.py
test/attributes.py
test/mapper.py
test/masscreate.py
test/objectstore.py

index 2cac5661929cb04045c91392355f21ca9dbc2e88..d1b0738de8745ee95d5be711b7d1e6f7dc9c269e 100644 (file)
@@ -37,10 +37,20 @@ class SmartProperty(object):
     create_prop method on AttributeManger, which can be overridden to provide
     subclasses of SmartProperty.
     """
-    def __init__(self, manager, key, uselist):
+    def __init__(self, manager, key, uselist, callable_, **kwargs):
         self.manager = manager
         self.key = key
         self.uselist = uselist
+        self.callable_ = callable_
+        self.kwargs = kwargs
+    def init(self, obj, attrhist=None):
+        """creates an appropriate ManagedAttribute for the given object and establishes
+        it with the object's list of managed attributes."""
+        if self.callable_ is not None:
+            func = self.callable_(obj)
+        else:
+            func = None
+        return self.manager.create_managed_attribute(obj, self.key, self.uselist, callable_=func, attrdict=attrhist, **self.kwargs)
     def __set__(self, obj, value):
         self.manager.set_attribute(obj, self.key, value)
     def __delete__(self, obj):
@@ -300,10 +310,11 @@ class AttributeManager(object):
         upon an attribute change of value."""
         pass
         
-    def create_prop(self, class_, key, uselist, **kwargs):
+    def create_prop(self, class_, key, uselist, callable_, **kwargs):
         """creates a scalar property object, defaulting to SmartProperty, which 
         will communicate change events back to this AttributeManager."""
-        return SmartProperty(self, key, uselist)
+        return SmartProperty(self, key, uselist, callable_, **kwargs)
+        
     def create_list(self, obj, key, list_, **kwargs):
         """creates a history-aware list property, defaulting to a ListAttribute which
         is a subclass of HistoryArrayList."""
@@ -365,22 +376,19 @@ class AttributeManager(object):
         # currently a no-op since the state of the object is attached to the object itself
         pass
 
-    def create_history(self, obj, key, uselist, callable_=None, **kwargs):
-        """creates a new "history" container for a specific attribute on the given object.  
-        this can be used to override a class-level attribute with something different,
-        such as a callable. """
-        p = self.create_history_container(obj, key, uselist, callable_=callable_, **kwargs)
-        self.attribute_history(obj)[key] = p
-        return p
 
     def init_attr(self, obj):
-        """sets up the _managed_attributes dictionary on an object.  this happens anyway regardless
-        of this method being called, but saves on KeyErrors being thrown in get_history()."""
+        """sets up the _managed_attributes dictionary on an object.  this happens anyway 
+        when a particular attribute is first accessed on the object regardless
+        of this method being called, however calling this first will result in an elimination of 
+        AttributeError/KeyErrors that are thrown when get_unexec_history is called for the first
+        time for a particular key."""
         d = {}
         obj._managed_attributes = d
-        cls_managed = self.class_managed(obj.__class__)
-        for value in cls_managed.values():
-            value(obj, d).plain_init(d)
+        for value in obj.__class__.__dict__.values():
+            if not isinstance(value, SmartProperty):
+                continue
+            value.init(obj, attrhist=d).plain_init(d)
 
     def get_unexec_history(self, obj, key):
         """returns the "history" container for the given attribute on the given object.
@@ -389,31 +397,35 @@ class AttributeManager(object):
         try:
             return obj._managed_attributes[key]
         except AttributeError, ae:
-            return self.class_managed(obj.__class__)[key](obj)
+            return getattr(obj.__class__, key).init(obj)
         except KeyError, e:
-            return self.class_managed(obj.__class__)[key](obj)
+            return getattr(obj.__class__, key).init(obj)
 
     def get_history(self, obj, key, **kwargs):
-        """returns the "history" container, and calls its history() method,
-        which for a TriggeredAttribute will execute the underlying callable and return the
-        resulting ScalarAttribute or ListHistory object."""
+        """accesses the appropriate ManagedAttribute container and calls its history() method.
+        For a TriggeredAttribute this will execute the underlying callable and return the
+        resulting ScalarAttribute or ListAttribute object.  For an existing ScalarAttribute
+        or ListAttribute, just returns the container."""
         return self.get_unexec_history(obj, key).history(**kwargs)
 
     def attribute_history(self, obj):
-        """returns a dictionary of "history" containers corresponding to the given object.
+        """returns a dictionary of ManagedAttribute containers corresponding to the given object.
         this dictionary is attached to the object via the attribute '_managed_attributes'.
-        If the dictionary does not exist, it will be created."""
+        If the dictionary does not exist, it will be created.  If a 'trigger' has been placed on 
+        this object via the trigger_history() method, it will first be executed."""
         try:
             return obj._managed_attributes
         except AttributeError:
+            obj._managed_attributes = {}
             trigger = obj.__dict__.pop('_managed_trigger', None)
             if trigger:
                 trigger()
-            attr = {}
-            obj._managed_attributes = attr
-            return attr
+            return obj._managed_attributes
 
     def trigger_history(self, obj, callable):
+        """removes all ManagedAttribute instances from the given object and places the given callable
+        as an attribute-wide "trigger", which will execute upon the next attribute access, after
+        which the trigger is removed and the object re-initialized to receive new ManagedAttributes. """
         try:
             del obj._managed_attributes
         except KeyError:
@@ -440,62 +452,42 @@ class AttributeManager(object):
             except KeyError:
                 pass
         
-    def class_managed(self, class_):
-        """returns a dictionary of "history container definitions", which is attached to a 
-        class.  creates the dictionary if it doesnt exist."""
-        try:
-            attr = class_._class_managed_attributes
-        except AttributeError:
-            attr = {}
-            class_._class_managed_attributes = attr
-            class_._attribute_manager = self
-        return attr
-
     def reset_class_managed(self, class_):
-        try:
-            attr = class_._class_managed_attributes
-            for key in attr.keys():
-                delattr(class_, key)
-            del class_._class_managed_attributes
-        except AttributeError:
-            pass
+        for value in class_.__dict__.values():
+            if not isinstance(value, SmartProperty):
+                continue
+            delattr(class_, value.key)
 
     def is_class_managed(self, class_, key):
-        try:
-            return class_._class_managed_attributes.has_key(key)
-        except AttributeError:
-            return False
-            
-    def create_history_container(self, obj, key, uselist, callable_ = None, **kwargs):
-        """creates a new history container for the given attribute on the given object."""
+        return hasattr(class_, key) and isinstance(getattr(class_, key), SmartProperty)
+
+    def create_managed_attribute(self, obj, key, uselist, callable_=None, attrdict=None, **kwargs):
+        """creates a new ManagedAttribute corresponding to the given attribute key on the 
+        given object instance, and installs it in the attribute dictionary attached to the object."""
         if callable_ is not None:
-            return self.create_callable(obj, key, callable_, uselist=uselist, **kwargs)
+            prop = self.create_callable(obj, key, callable_, uselist=uselist, **kwargs)
         elif not uselist:
-            return ScalarAttribute(obj, key, **kwargs)
+            prop = ScalarAttribute(obj, key, **kwargs)
         else:
-            return self.create_list(obj, key, None, **kwargs)
-        
+            prop = self.create_list(obj, key, None, **kwargs)
+        if attrdict is None:
+            attrdict = self.attribute_history(obj)
+        attrdict[key] = prop
+        return prop
+    
+    # deprecated
+    create_history=create_managed_attribute
+    
     def register_attribute(self, class_, key, uselist, callable_=None, **kwargs):
         """registers an attribute's behavior at the class level.  This attribute
         can be scalar or list based, and also may have a callable unit that will be
-        used to create the initial value.  The definition for this attribute is 
-        wrapped up into a callable which is then stored in the classes' 
-        dictionary of "class managed" attributes.  When instances of the class 
+        used to create the initial value (i.e. a lazy loader).  The definition for this attribute is 
+        wrapped up into a callable which is then stored in the corresponding
+        SmartProperty object attached to the class.  When instances of the class 
         are created and the attribute first referenced, the callable is invoked with
-        the new object instance as an argument to create the new history container.  
+        the new object instance as an argument to create the new ManagedAttribute.  
         Extra keyword arguments can be sent which
-        will be passed along to newly created history containers."""
-        def createprop(obj, attrhist=None):
-            if callable_ is not None: 
-                func = callable_(obj)
-            else:
-                func = None
-            p = self.create_history_container(obj, key, uselist, callable_=func, **kwargs)
-            if attrhist is None:
-                attrhist = self.attribute_history(obj)
-            attrhist[key] = p
-            return p
-        
-        self.class_managed(class_)[key] = createprop
-        setattr(class_, key, self.create_prop(class_, key, uselist))
+        will be passed along to newly created ManagedAttribute."""
+        class_._attribute_manager = self
+        setattr(class_, key, self.create_prop(class_, key, uselist, callable_, **kwargs))
 
index 7c3ca8fd70eb91e2a4978faf4ba740746b5a84ff..b08836a208d03898d7bc3e3560335a3722f856e8 100644 (file)
@@ -66,8 +66,8 @@ class UOWAttributeManager(attributes.AttributeManager):
         else:
             get_session(obj).register_new(obj)
             
-    def create_prop(self, class_, key, uselist, **kwargs):
-        return UOWProperty(class_, self, key, uselist)
+    def create_prop(self, class_, key, uselist, callable_, **kwargs):
+        return UOWProperty(class_, self, key, uselist, callable_, **kwargs)
 
     def create_list(self, obj, key, list_, **kwargs):
         return UOWListElement(obj, key, list_, **kwargs)
index ca34bdfc7af05af24b90b4f0c3cc6e26bdc1919a..446ea56193a59cef10109a1abe3cbbd0c929c427 100644 (file)
@@ -131,6 +131,30 @@ class AttributesTest(PersistTest):
         
         j.port = None
         self.assert_(p.jack is None)
+
+    def testinheritance(self):
+        """tests that attributes are polymorphic"""
+        class Foo(object):pass
+        class Bar(Foo):pass
+        
+        manager = attributes.AttributeManager()
+        
+        def func1():
+            return "this is the foo attr"
+        def func2():
+            return "this is the bar attr"
+        def func3():
+            return "this is the shared attr"
+        manager.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1)
+        manager.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3)
+        manager.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2)
+        
+        x = Foo()
+        y = Bar()
+        assert x.element == 'this is the foo attr'
+        assert y.element == 'this is the bar attr'
+        assert x.element2 == 'this is the shared attr'
+        assert y.element2 == 'this is the shared attr'
         
 if __name__ == "__main__":
     unittest.main()
index 299b54f0110e17095608165bc5ab9ccffd0368f2..5c1b268eb019c6bc7c200c8811a73cf79acb4154 100644 (file)
@@ -125,12 +125,15 @@ class MapperTest(MapperSuperTest):
         self.assert_sql_count(db, go, 1)
 
     def testexpire(self):
-        m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
+        m = mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
         u = m.get(7)
+        assert(len(u.addresses) == 1)
         u.user_name = 'foo'
+        del u.addresses[0]
         objectstore.expire(u)
         # test plain expire
         self.assert_(u.user_name =='jack')
+        self.assert_(len(u.addresses) == 1)
         
         # we're changing the database here, so if this test fails in the middle,
         # it'll screw up the other tests which are hardcoded to 7/'jack'
index 9e0900cee547abbd370aebbc1599f509849496aa..4321c210fa7dae3ec3bdda7fb2876504c4488e58 100644 (file)
@@ -35,4 +35,4 @@ for i in range(0,130):
         u.addresses.append(u)
 
 total = time.time() - now
-print "Total time", total
\ No newline at end of file
+print "Total time", total
index 72b5e16cec37e8cc841736ea6866ccfbce26fc96..97b7e817dcbaa89f224ce3cfa8191a769d866ba1 100644 (file)
@@ -1078,7 +1078,6 @@ class SaveTest(AssertMixin):
         a.user = u
         objectstore.commit()
         print repr(u.addresses)
-        print repr(u.addresses)
         x = False
         try:
             u.addresses.append('hi')
@@ -1087,7 +1086,7 @@ class SaveTest(AssertMixin):
             pass
             
         if x:
-            self.assert_(False, "User addresses element should be read-only")
+            self.assert_(False, "User addresses element should be scalar based")
         
         objectstore.delete(u)
         objectstore.commit()