]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Oct 2005 07:02:29 +0000 (07:02 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 13 Oct 2005 07:02:29 +0000 (07:02 +0000)
lib/sqlalchemy/mapper.py

index 855c7e289aab788d6ee35084a0904d53177d8faf..ae8f70859127520ad1c54bf6741e566c523b6ec6 100644 (file)
@@ -22,7 +22,7 @@ import sqlalchemy.util as util
 import sqlalchemy.objectstore as objectstore
 import random, copy, types
 
-__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql']
+__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql', 'MapperExtension']
 
 def relation(*args, **params):
     if isinstance(args[0], type) and len(args) == 1:
@@ -32,14 +32,16 @@ def relation(*args, **params):
     else:
         return relation_mapper(*args, **params)
 
-def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin = None, lazy = True, **options):
+def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin = None, lazy = True, **kwargs):
     if lazy:
-        return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
+        return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
+    elif lazy is None:
+        return PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     else:
-        return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
+        return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     
-def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, thiscol = None, **options):
-    return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, thiscol = thiscol, **options)
+def relation_mapper(class_, table=None, secondary=None, primaryjoin=None, secondaryjoin=None, **kwargs):
+    return relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, **kwargs)
 
 class assignmapper(object):
     def __init__(self, table, **kwargs):
@@ -110,8 +112,13 @@ class Mapper(object):
                 is_primary = False, 
                 inherits = None, 
                 inherit_condition = None, 
+                extension = None,
                 **kwargs):
-                
+
+        if extension is None:
+            self.extension = MapperExtension()
+        else:
+            self.extension = extension                
         self.hashkey = hashkey
         self.class_ = class_
         self.scope = scope
@@ -174,6 +181,7 @@ class Mapper(object):
         if properties is not None:
             for key, prop in properties.iteritems():
                 if isinstance(prop, schema.Column):
+                    self.columns[key] = prop
                     prop = ColumnProperty(prop)
                 self.props[key] = prop
                 if isinstance(prop, ColumnProperty):
@@ -264,8 +272,11 @@ class Mapper(object):
             except IndexError:
                 return None
 
-    def identity_key(self, instance):
-        return objectstore.get_id_key(tuple([self._getattrbycolumn(instance, column) for column in self.primary_keys[self.table]]), self.class_, self.primarytable)
+    def identity_key(self, *primary_keys):
+        return objectstore.get_id_key(tuple(primary_keys), self.class_, self.primarytable)
+    
+    def instance_key(self, instance):
+        return self.identity_key(**[self._getattrbycolumn(instance, column) for column in self.primary_keys[self.table]])
 
     def compile(self, whereclause = None, **options):
         """works like select, except returns the SQL statement object without 
@@ -418,8 +429,6 @@ class Mapper(object):
         identitykey = self._identity_key(row)
         if objectstore.uow().has_key(identitykey):
             instance = objectstore.uow()._get(identitykey)
-            if result is not None:
-                result.append_nohistory(instance)
 
             if populate_existing:
                 isnew = not imap.has_key(identitykey)
@@ -428,6 +437,10 @@ class Mapper(object):
                 for prop in self.props.values():
                     prop.execute(instance, row, identitykey, imap, isnew)
 
+            if self.extension.append_result(self, row, imap, result, instance, populate_existing=populate_existing):
+                if result is not None:
+                    result.append_nohistory(instance)
+
             return instance
                     
         # look in result-local identitymap for it.
@@ -439,7 +452,8 @@ class Mapper(object):
                 if row[col.label] is None:
                     return None
             # plugin point
-            instance = self.class_()
+            if self.extension.create_instance(self, row, imap, self.class_) is None:
+                instance = self.class_()
             instance._mapper = self.hashkey
             instance._instance_key = identitykey
 
@@ -449,8 +463,6 @@ class Mapper(object):
             instance = imap[identitykey]
             isnew = False
 
-        if result is not None:
-            result.append_nohistory(instance)
 
         # plugin point
         
@@ -458,6 +470,10 @@ class Mapper(object):
         # instances from the row and possibly populate this item.
         for prop in self.props.values():
             prop.execute(instance, row, identitykey, imap, isnew)
+
+        if self.extension.append_result(self, row, imap, result, instance, populate_existing=populate_existing):
+            if result is not None:
+                result.append_nohistory(instance)
             
         return instance
 
@@ -533,7 +549,7 @@ class PropertyLoader(MapperProperty):
 
     """describes an object property that holds a single item or list of items that correspond to a related
     database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None, **kwargs):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
@@ -806,7 +822,10 @@ class PropertyLoader(MapperProperty):
                     associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])
                 elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
                     associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
-            
+
+    def execute(self, instance, row, identitykey, imap, isnew):
+        pass
+
 class LazyLoader(PropertyLoader):
     def execute(self, instance, row, identitykey, imap, isnew):
         if isnew:
@@ -962,6 +981,12 @@ class BinaryVisitor(sql.ClauseVisitor):
     def visit_binary(self, binary):
         self.func(binary)
         
+
+class MapperExtension(object):
+    def create_instance(self, mapper, row, imap, class_):
+        return None
+    def append_result(self, mapper, row, imap, result, instance, populate_existing=False):
+        return True
   
 def hash_key(obj):
     if obj is None: