]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Sep 2005 01:10:23 +0000 (01:10 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Sep 2005 01:10:23 +0000 (01:10 +0000)
lib/sqlalchemy/mapper.py
lib/sqlalchemy/util.py
test/mapper.py

index 74b72dd1bf6f74c4cf68efeecef2224143abf95e..969a93a238ec9f75278494d509805f481b87003c 100644 (file)
@@ -35,8 +35,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
     
-def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, **options):
-    return relation_loader(mapper(class_, selectable, table = table, properties = properties, isroot = False, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, **options)
+def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, **options):
+    return relation_loader(mapper(class_, selectable, table = table, properties = properties, isroot = False, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, **options)
 
 _mappers = {}
 def mapper(*args, **params):
@@ -54,9 +54,8 @@ def lazyload(name):
     return EagerLazySwitcher(name, toeager = False)
 
 class Mapper(object):
-    def __init__(self, class_, selectable, table = None, scope = "thread", properties = None, use_smart_properties = True, isroot = True, echo = None):
+    def __init__(self, class_, selectable, table = None, scope = "thread", properties = None, isroot = True, echo = None):
         self.class_ = class_
-        self.use_smart_properties = use_smart_properties
         self.scope = scope
         self.selectable = selectable
         tf = TableFinder()
@@ -124,7 +123,6 @@ class Mapper(object):
             self.table,
             self.properties,
             self.scope,
-            self.use_smart_properties,
             self.echo
         )
 
@@ -385,32 +383,26 @@ class ColumnProperty(MapperProperty):
 
     def init(self, key, parent, root):
         self.key = key
-        if root.use_smart_properties:
-            self.use_smart = True
-            if not hasattr(parent.class_, key):
-                setattr(parent.class_, key, SmartProperty(key).property())
-        else:
-            self.use_smart = False
+        if not hasattr(parent.class_, key):
+            setattr(parent.class_, key, SmartProperty(key).property())
 
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         if not isduplicate:
-            if self.use_smart:
-                clean_setattr(instance, self.key, row[self.columns[0].label])
-            else:
-                setattr(instance, self.key, row[self.columns[0].label])
+            clean_setattr(instance, self.key, row[self.columns[0].label])
 
 
 
 class PropertyLoader(MapperProperty):
     """describes an object property that holds a list of items that correspond to a related
     database table."""
-    def __init__(self, mapper, secondary, primaryjoin, secondaryjoin):
+    def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, uselist = True):
+        self.uselist = uselist
         self.mapper = mapper
         self.target = self.mapper.selectable
         self.secondary = secondary
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
-        self._hash_key = "%s(%s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin))
+        self._hash_key = "%s(%s, %s, %s, %s, uselist=%s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), repr(self.uselist))
 
     def hash_key(self):
         return self._hash_key
@@ -427,6 +419,9 @@ class PropertyLoader(MapperProperty):
         else:
             if self.primaryjoin is None:
                 self.primaryjoin = match_primaries(parent.selectable, self.target)
+                
+        if not self.uselist and not hasattr(parent.class_, key):
+            setattr(parent.class_, key, SmartProperty(key).property(usehistory = True))
 
     def save(self, obj, traverse):
         # saves child objects
@@ -436,11 +431,19 @@ class PropertyLoader(MapperProperty):
             secondary_insert = []
              
         setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
-        childlist = getattr(obj, self.key)
-        if not isinstance(childlist, util.HistoryArraySet):
-            childlist = util.HistoryArraySet(childlist)
-            clean_setattr(obj, self.key, childlist)
-            
+        
+        
+        if self.uselist:
+            childlist = getattr(obj, self.key)
+            if not isinstance(childlist, util.HistoryArraySet):
+                childlist = util.HistoryArraySet(childlist)
+                clean_setattr(obj, self.key, childlist)
+        else:
+            childlist = GetPropHistory()
+            # this is a nasty trick to communicate with a property()
+            setattr(obj, self.key, childlist)
+            childlist = childlist.history
+
         for child in childlist.deleted_items():
             setter.child = child
             setter.associationrow = {}
@@ -452,6 +455,7 @@ class PropertyLoader(MapperProperty):
                 secondary_delete.append(setter.associationrow)
                 
         for child in childlist.added_items():
+            print "yup " + repr(child)
             setter.child = child
             setter.associationrow = {}
             self.primaryjoin.accept_visitor(setter)
@@ -515,8 +519,16 @@ class LazyLoadInstance(object):
         # quickly, so an object with a lazyloader still cant really be serialized
         self.mapper = lazyloader.mapper
         self.lazywhere = lazyloader.lazywhere
+        self.uselist = lazyloader.uselist
     def __call__(self):
-        return self.mapper.select(self.lazywhere, **self.params)
+        result = self.mapper.select(self.lazywhere, **self.params)
+        if self.uselist:
+            return result
+        else:
+            if len(result):
+                return result[0]
+            else:
+                return None
 
 class EagerLoader(PropertyLoader):
     """loads related objects inline with a parent query."""
@@ -561,14 +573,19 @@ class EagerLoader(PropertyLoader):
     def execute(self, instance, row, identitykey, localmap, isduplicate):
         """receive a row.  tell our mapper to look for a new object instance in the row, and attach
         it to a list on the parent instance."""
-        if not isduplicate:
+        if not self.uselist:
+            result_list = []
+        elif not isduplicate:
             result_list = util.HistoryArraySet()
             clean_setattr(instance, self.key, result_list)
         else:
             result_list = getattr(instance, self.key)
 
         self.mapper._instance(row, localmap, result_list)
-
+        
+        if not self.uselist:
+            clean_setattr(instance, self.key, result_list[0])
+            
 class LazyRow(MapperProperty):
     """TODO: this will lazy-load additional properties of an object from a secondary table."""
     def __init__(self, table, whereclause, **options):
@@ -679,11 +696,29 @@ class SmartProperty(object):
     def __init__(self, key):
         self.key = key
 
-    def property(self):
+    def get_history(self, obj):
+        if not hasattr(obj, '_history'):
+            obj._history = {}
+        if not obj._history.has_key(self.key):
+            obj._history[self.key] = util.PropHistory(obj.__dict__.get(self.key, None))
+        return obj._history[self.key]
+        
+    def property(self, usehistory = False):
+        # TODO: all the history/dirty crap here is temporary, should communicate with a 
+        # thread-local unit of work
         def set_prop(s, value):
+            if usehistory:
+                hist = self.get_history(s)
+                if isinstance(value, GetPropHistory):
+                    value.history = hist
+                    return
+                hist.setattr(value, s.__dict__.get(self.key, None))
             s.__dict__[self.key] = value
             s.dirty = True
         def del_prop(s):
+            if usehistory:
+                hist = self.get_history(s)
+                hist.delattr(value)
             del s.__dict__[self.key]
             s.dirty = True
         def get_prop(s):
@@ -696,6 +731,8 @@ class SmartProperty(object):
             return s.__dict__[self.key]
         return property(get_prop, set_prop, del_prop)
 
+class GetPropHistory:pass
+        
 identity_map = util.ScopedRegistry(lambda: {})
   
 def clean_setattr(object, key, value):
@@ -714,17 +751,16 @@ def hash_key(obj):
     else:
         return obj.hash_key()
 
-def mapper_hash_key(class_, selectable, table = None, properties = None, scope = "thread", use_smart_properties = True, isroot = True, echo = None):
+def mapper_hash_key(class_, selectable, table = None, properties = None, scope = "thread", isroot = True, echo = None):
     if properties is None:
         properties = {}
     return (
-        "Mapper(%s, %s, table=%s, properties=%s, scope=%s, use_smart_properties=%s, echo=%s)" % (
+        "Mapper(%s, %s, table=%s, properties=%s, scope=%s, echo=%s)" % (
             repr(class_),
             hash_key(selectable),
             hash_key(table),
             repr(dict([(k, hash_key(p)) for k,p in properties.iteritems()])),
             scope,
-            repr(use_smart_properties),
             repr(echo)
 
         )
index e6802bbbb0e63ee9dafd785abdf1c44eb26c7a03..7013e6e6cfcb916e7394e0a08aef5777eeffa019 100644 (file)
@@ -203,6 +203,36 @@ class HistoryArraySet(UserList.UserList):
         raise NotImplementedError()
     def __iadd__(self, other):
         raise NotImplementedError()
+
+class PropHistory(object):
+    def __init__(self, current):
+        self.added = None
+        self.current = current
+        self.deleted = None
+    def setattr(self, value, current):
+        self.current = None
+        self.deleted = current
+        self.added = value
+    def delattr(self, current):
+        self.deleted = current
+    def clear_history(self):
+        if self.added is not None:
+            self.current = self.added
+    def added_items(self):
+        if self.added is not None:
+            return [self.added]
+        else:
+            return []
+    def deleted_items(self):
+        if self.deleted is not None:
+            return [self.deleted]
+        else:
+            return []
+    def unchanged_items(self):
+        if self.current is not None:
+            return [self.current]
+        else:
+            return []
         
 class ScopedRegistry(object):
     def __init__(self, createfunc):
index ba1a20e90ad1191e696fa2bf8c6de1898b9fc892..673884573b227cd2bc41b837071273596b83c9cc 100644 (file)
@@ -23,7 +23,7 @@ Closed Orderss %s
 
 class Address(object):
     def __repr__(self):
-        return "Address: " + repr(self.address_id) + " " + repr(self.user_id) + " " + repr(self.email_address)
+        return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'user_id', None)) + " " + repr(self.email_address)
 
 class Order(object):
     def __repr__(self):
@@ -310,6 +310,16 @@ class SaveTest(AssertMixin):
         u = m.select(users.c.user_id==u.foo_id)[0]
         print repr(u.__dict__)
 
+    def testonetoone(self):
+        m = mapper(User, users, properties = dict(
+            address = relation(Address, addresses, lazy = True, uselist = False)
+        ))
+        u = User()
+        u.user_name = 'one2onetester'
+        u.address = Address()
+        u.address.email_address = 'myonlyaddress@foo.com'
+        m.save(u)
+        
     def testonetomany(self):
         """test basic save of one to many."""
         m = mapper(User, users, properties = dict(