]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2005 03:16:24 +0000 (03:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2005 03:16:24 +0000 (03:16 +0000)
lib/sqlalchemy/mapper.py

index 061fafbb7892ac3b2019197e767cc8c76101aa2a..f3b1f4ec70c1dc3b09b8b65f6b7d1b6754c0c8ad 100644 (file)
@@ -32,17 +32,22 @@ import weakref, random, copy
 
 __ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'mapper', 'lazyloader', 'lazymapper', 'identitymap', 'globalidentity']
 
-def lazymapper(class_, selectable, whereclause, table = None, properties = None, **options):
-    return lazyloader(mapper(class_, selectable, table = table, properties = properties, isroot = False), whereclause, **options)
-    
-def eagermapper(class_, selectable, whereclause, table = None, properties = None, **options):
-    return eagerloader(mapper(class_, selectable, table = table, properties = properties, isroot = False), whereclause, **options)
-
-def eagerloader(mapper, whereclause, **options):
-    return EagerLoader(mapper, whereclause, **options)
 
-def lazyloader(mapper, whereclause, **options):
-    return LazyLoader(mapper, whereclause, **options)
+def relation(*args, **params):
+    #multimethod poverty
+    if type(args[0]) == Mapper:
+        return relation_loader(*args, **params)
+    else:
+        return relation_mapper(*args, **params)
+
+def relation_loader(mapper, whereclause, lazy = True, **options):
+    if lazy:
+        return LazyLoader(mapper, whereclause, **options)
+    else:
+        return EagerLoader(mapper, whereclause, **options)
+    
+def relation_mapper(class_, selectable, whereclause, table = None, properties = None, lazy = True, **options):
+    return relation_loader(mapper(class_, selectable, table = table, properties = properties, isroot = False), whereclause, lazy = lazy, **options)
 
 def mapper(class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True):
     return Mapper(class_, selectable, table = table, properties = properties, identitymap = identitymap, use_smart_properties = use_smart_properties, isroot = isroot)
@@ -219,7 +224,7 @@ class Mapper(object):
         else:
             instance = self.identitymap[identitykey]
         instance.dirty = False
-        
+
         # call further mapper properties on the row, to pull further 
         # instances from the row and possibly populate this item.
         for key, prop in self.props.iteritems():
@@ -230,7 +235,7 @@ class Mapper(object):
         # that is associated with that list
         try:
             imap = localmap[id(result)]
-        except:
+        except KeyError:
             imap = localmap.setdefault(id(result), IdentityMap())
         if not imap.has_key(identitykey):
             imap[identitykey] = instance
@@ -242,7 +247,7 @@ class MapperOption:
         raise NotImplementedError()
         
 class MapperProperty:
-    def execute(self, instance, row, isduplicate):
+    def execute(self, instance, row, identitykey, localmap, isduplicate):
         """called when the mapper receives a row.  instance is the parent instance corresponding
         to the row. """
         raise NotImplementedError()
@@ -301,8 +306,26 @@ class PropertyLoader(MapperProperty):
         self.mapper.delete()
         
 class LazyLoader(PropertyLoader):
-    pass
-    
+    def setup(self, key, primarytable, statement, **options):
+        self.lazywhere = self.whereclause.copy_structure()
+        li = LazyIzer(primarytable)
+        self.lazywhere.accept_visitor(li)
+        self.binds = li.binds
+
+    def init(self, key, parent, root):
+        PropertyLoader.init(self, key, parent, root)
+        setattr(parent.class_, key, SmartProperty(key).property())
+
+    def execute(self, instance, row, identitykey, localmap, isduplicate):
+        if not isduplicate:
+            def load():
+                m = {}
+                for key, value in self.binds.iteritems():
+                    m[key] = row[key]
+                return self.mapper.select(**m)
+
+            setattr(instance, self.key, load)
+        
 class EagerLoader(PropertyLoader):
     def setup(self, key, primarytable, statement, **options):
         """add a left outer join to the statement thats being constructed"""
@@ -369,21 +392,45 @@ class Aliasizer(sql.ClauseVisitor):
         if isinstance(binary.right, schema.Column) and binary.right.table == self.table:
             binary.right = self.alias.c[binary.right.name]
 
+class LazyIzer(sql.ClauseVisitor):
+    def __init__(self, table):
+        self.table = table
+        self.binds = {}
+        
+    def visit_binary(self, binary):
+        if isinstance(binary.left, schema.Column) and binary.left.table == self.table:
+            binary.left = self.binds.setdefault(binary.left.name,
+                    sql.BindParamClause(self.table.name + "_" + binary.left.name, None, shortname = binary.left.name))
+
+        if isinstance(binary.right, schema.Column) and binary.right.table == self.table:
+            binary.right = self.binds.setdefault(binary.right.name,
+                    sql.BindParamClause(self.table.name + "_" + binary.right.name, None, shortname = binary.left.name))
+    
+
+
 class SmartProperty(object):
     def __init__(self, key):
         self.key = key
 
     def property(self):
         def set_prop(s, value):
+            print "hi setting is " + repr(value)
+            raise "hi"
             s.__dict__[self.key] = value
             s.dirty = True
         def del_prop(s):
             del s.__dict__[self.key]
             s.dirty = True
         def get_prop(s):
+            v = s.__dict__[self.key]
+            # TODO: this sucks a little
+            print "hi thing is " + repr(v)
+            if isinstance(v, types.FunctionType):
+                s.__dict__[self.key] = v()
             return s.__dict__[self.key]
         return property(get_prop, set_prop, del_prop)
-        
+
+
 class IdentityMap(dict):
     def get_key(self, row, class_, table, selectable):
         return (class_, table, tuple([row[column.label] for column in selectable.primary_keys]))