]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2005 05:17:05 +0000 (05:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 23 Sep 2005 05:17:05 +0000 (05:17 +0000)
lib/sqlalchemy/mapper.py

index bf137bcf87a27769ad4be8955ee053312e6afef5..dc73283450eb48a2f96d16d26c7b5770c9bca764 100644 (file)
@@ -22,7 +22,7 @@ import sqlalchemy.util as util
 import sqlalchemy.objectstore as objectstore
 import random, copy, types
 
-__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql']
+__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql']
 
 def relation(*args, **params):
     if isinstance(args[0], Mapper):
@@ -36,19 +36,36 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
     
-def relation_mapper(class_, table, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, **options):
+def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, **options):
     return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, **options)
 
+class assignmapper(object):
+    def __init__(self, table, **kwargs):
+        self.table = table
+        self.kwargs = kwargs
+        
+    def __get__(self, instance, owner):
+        if not hasattr(self, 'mapper'):
+            self.mapper = mapper(owner, self.table, **self.kwargs)
+            print "HI"
+            self.mapper._init_class()
+            if self.mapper.class_ is not owner:
+                raise "no match " + repr(self.mapper.class_) + " " + repr(owner)
+            if not hasattr(owner, 'c'):
+                raise "no c"
+        return self.mapper
     
-# TODO: where do we want to register these mappers, register them against their classes/objects etc
 _mappers = {}
-def mapper(*args, **params):
-    hashkey = mapper_hash_key(*args, **params)
+def mapper(class_, table = None, *args, **params):
+    if table is None:
+        return class_mapper(class_)
+            
+    hashkey = mapper_hash_key(class_, table, *args, **params)
     #print "HASHKEY: " + hashkey
     try:
         return _mappers[hashkey]
     except KeyError:
-        m = Mapper(hashkey, *args, **params)
+        m = Mapper(hashkey, class_, table, *args, **params)
         return _mappers.setdefault(hashkey, m)
 
 def clear_mappers():
@@ -64,23 +81,53 @@ def object_mapper(object):
     try:
         return _mappers[object._mapper]
     except AttributeError:
-        try:
-            return _mappers[object.__class__._mapper]
-        except AttributeError:
-            raise "Object " + object.__class__.__name__ + "/" + repr(id(object)) + " has no mapper specified"
+        return class_mapper(object.__class__)
+
+def class_mapper(class_):
+    try:
+        return _mappers[class_._mapper]
+    except KeyError:
+        pass
+    except AttributeError:
+        pass
+        raise "Class '%s' has no mapper associated with it" % class_.__name__
         
 class Mapper(object):
-    def __init__(self, hashkey, class_, table, primarytable = None, scope = "thread", properties = None, primary_keys = None, inherits = None, **kwargs):
+    def __init__(self, 
+                hashkey, 
+                class_, 
+                table, 
+                primarytable = None, 
+                scope = "thread", 
+                properties = None, 
+                primary_keys = None, 
+                is_primary = False, 
+                inherits = None, 
+                inherit_condition = None, 
+                **kwargs):
+                
         self.hashkey = hashkey
         self.class_ = class_
         self.scope = scope
+        self.is_primary = is_primary
+        
+        if not issubclass(class_, object):
+            raise "Class '%s' is not a new-style class" % class_.__name__
+
         if inherits is not None:
-            table = table.join(inherits.table)
+            # TODO: determine inherit_condition (make JOIN do natural joins)
+            primarytable = inherits.primarytable
+            table = sql.join(table, inherits.table, inherit_condition)
+            
         self.table = table
+            
+        # locate all tables contained within the "table" passed in, which
+        # may be a join or other construct
         tf = TableFinder()
         self.table.accept_visitor(tf)
         self.tables = tf.tables
-        self.primary_keys = {}
+
+        # determine "primary" table        
         if primarytable is None:
             if len(self.tables) > 1:
                 raise "table contains multiple tables - specify primary table argument to Mapper"
@@ -88,6 +135,8 @@ class Mapper(object):
         else:
             self.primarytable = primarytable
 
+        # determine primary keys, either passed in, or get them from our set of tables
+        self.primary_keys = {}
         if primary_keys is not None:
             for k in primary_keys:
                 self.primary_keys.setdefault(k.table, []).append(k)
@@ -103,7 +152,8 @@ class Mapper(object):
                 for k in t.primary_keys:
                     list.append(k)
 
-        self.columns = self.table.columns
+        # make table columns addressable via the mapper
+        self.columns = util.OrderedProperties()
         self.c = self.columns
         
         # object attribute names mapped to MapperProperty objects
@@ -114,19 +164,9 @@ class Mapper(object):
         # populating multiple object attributes
         self.columntoproperty = {}
         
-        # the original properties argument to match against similar 
-        # arguments, for caching purposes
-        self.properties = properties
-
-        if inherits is not None and inherits.properties is not None:
-            if self.properties is None:
-                self.properties = {}
-            for key in inherits.properties.keys():
-                self.properties.setdefault(key, inherits.properties[key])
-                
         # load custom properties 
-        if self.properties is not None:
-            for key, prop in self.properties.iteritems():
+        if properties is not None:
+            for key, prop in properties.iteritems():
                 self.props[key] = prop
                 if isinstance(prop, ColumnProperty):
                     for col in prop.columns:
@@ -136,6 +176,8 @@ class Mapper(object):
         # load properties from the main table object,
         # not overriding those set up in the 'properties' argument
         for column in self.table.columns:
+            self.columns[column.key] = column
+
             if self.columntoproperty.has_key(column.original):
                 continue
                 
@@ -153,24 +195,32 @@ class Mapper(object):
             proplist = self.columntoproperty.setdefault(column.original, [])
             proplist.append(prop)
 
-        self.init()
-
+        if inherits is not None:
+            for key, prop in inherits.props.iteritems():
+                if not self.props.has_key(key):
+                    self.props[key] = prop._copy()
+                
+                
+        if not hasattr(self.class_, '_mapper') or self.is_primary or not _mappers.has_key(self.class_._mapper):
+            self._init_class()
+        [prop.init(key, self) for key, prop in self.props.iteritems()]
+        
     engines = property(lambda s: [t.engine for t in s.tables])
 
     def __str__(self):
         return "Mapper|" + self.class_.__name__ + "|" + self.primarytable.name
     def hash_key(self):
         return self.hashkey
-        
+
+    def _init_class(self):
+        self.class_._mapper = self.hashkey
+        self.class_.c = self.c
+    
     def set_property(self, key, prop):
         self.props[key] = prop
         prop.init(key, self)
 
-    def init(self):
-        [prop.init(key, self) for key, prop in self.props.iteritems()]
-        # TODO: get some notion of "primary mapper" going so multiple mappers dont collide
-        self.class_._mapper = self.hashkey
-
+    
     def instances(self, cursor, db):
         result = util.HistoryArraySet()
         cursor = engine.ResultProxy(cursor, db, echo = db.echo)
@@ -415,6 +465,8 @@ class MapperProperty:
         """called when the mapper receives a row.  instance is the parent instance corresponding
         to the row. """
         raise NotImplementedError()
+    def _copy(self):
+        raise NotImplementedError()
     def hash_key(self):
         """describes this property and its instantiated arguments in such a way
         as to uniquely identify the concept this MapperProperty represents,within 
@@ -446,10 +498,14 @@ class ColumnProperty(MapperProperty):
     def hash_key(self):
         return "ColumnProperty(%s)" % repr([hash_key(c) for c in self.columns])
 
+    def _copy(self):
+        return ColumnProperty(*self.columns)
+        
     def init(self, key, parent):
         self.key = key
         # establish a SmartProperty property manager on the object for this key
         if not hasattr(parent.class_, key):
+            #print "regiser col on class %s key %s" % (parent.class_.__name__, key)
             objectstore.uow().register_attribute(parent.class_, key, uselist = False)
 
     def execute(self, instance, row, identitykey, imap, isnew):
@@ -472,11 +528,18 @@ class PropertyLoader(MapperProperty):
         self.foreignkey = foreignkey
         self.private = private
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist))
-            
+
+
+    def _copy(self):
+        return self.__class__(self.mapper, self.secondary, self.primaryjoin, self.secondaryjoin, self.foreignkey, self.uselist, self.private)
+        
     def hash_key(self):
         return self._hash_key
 
     def init(self, key, parent):
+        if isinstance(self.mapper, str):
+            self.mapper = object_mapper(self.mapper)
+            
         self.key = key
         self.parent = parent
         
@@ -512,6 +575,7 @@ class PropertyLoader(MapperProperty):
         (self.lazywhere, self.lazybinds) = create_lazy_clause(self.parent.table, self.primaryjoin, self.secondaryjoin)
                 
         if not hasattr(parent.class_, key):
+            #print "regiser list col on class %s key %s" % (parent.class_.__name__, key)
             objectstore.uow().register_attribute(parent.class_, key, uselist = self.uselist)
 
     class FindDependent(sql.ClauseVisitor):
@@ -649,10 +713,10 @@ class PropertyLoader(MapperProperty):
                         self.primaryjoin.accept_visitor(setter)
                     uowcommit.register_deleted_list(childlist)
                 if len(updates):
-                    parameters = {}
+                    values = {}
                     for bind in self.lazybinds.values():
-                        parameters[bind.shortname] = None
-                    statement = self.target.update(self.lazywhere, parameters = parameters)
+                        values[bind.shortname] = None
+                    statement = self.target.update(self.lazywhere, values = values)
                     statement.execute(*updates)
             else:
                 for obj in deplist:
@@ -764,7 +828,7 @@ class EagerLoader(PropertyLoader):
         [self.to_alias.append(f) for f in self.primaryjoin._get_from_objects()]
         if self.secondaryjoin is not None:
             [self.to_alias.append(f) for f in self.secondaryjoin._get_from_objects()]
-        del self.to_alias[parent.table]
+        del self.to_alias[parent.primarytable]
 
     def setup(self, key, statement, **options):
         """add a left outer join to the statement thats being constructed"""
@@ -790,7 +854,11 @@ class EagerLoader(PropertyLoader):
 
         statement.append_from(statement._outerjoin)
         statement.append_column(self.target)
+        print "coming in, mapper is " + str(self.mapper)
         for key, value in self.mapper.props.iteritems():
+            print "setup " + key
+            if value is self:
+                raise "wha?"
             value.setup(key, statement)
 
     def execute(self, instance, row, identitykey, imap, isnew):