]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
got rough 'topology sort' thing to work
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Sep 2005 06:05:18 +0000 (06:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 10 Sep 2005 06:05:18 +0000 (06:05 +0000)
lib/sqlalchemy/mapper.py

index a4eb5f5cee49fa1b7d81cc10038003968d957845..470ec016b947ed4ce51441a174fa5890724a3953 100644 (file)
@@ -36,8 +36,8 @@ def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
     
-def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, direction = None, **options):
-    return relation_loader(mapper(class_, selectable, table = table, properties = properties, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, direction = direction, **options)
+def relation_mapper(class_, selectable, secondary = None, primaryjoin = None, secondaryjoin = None, table = None, properties = None, lazy = True, uselist = True, foreignkey = None, **options):
+    return relation_loader(mapper(class_, selectable, table = table, properties = properties, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, uselist = uselist, foreignkey = foreignkey, **options)
 
 _mappers = {}
 def mapper(*args, **params):
@@ -46,7 +46,7 @@ def mapper(*args, **params):
     try:
         return _mappers[hashkey]
     except KeyError:
-        m = Mapper(*args, **params)
+        m = Mapper(hashkey, *args, **params)
         return _mappers.setdefault(hashkey, m)
     
 def eagerload(name):
@@ -65,7 +65,8 @@ def object_mapper(object):
             raise "Object " + object.__class__.__name__ + "/" + repr(id(object)) + " has no mapper specified"
         
 class Mapper(object):
-    def __init__(self, class_, selectable, table = None, scope = "thread", properties = None, echo = None, **kwargs):
+    def __init__(self, hashkey, class_, selectable, table = None, scope = "thread", properties = None, echo = None, **kwargs):
+        self.hashkey = hashkey
         self.class_ = class_
         self.scope = scope
         self.selectable = selectable
@@ -126,24 +127,15 @@ class Mapper(object):
         self.init()
 
     def hash_key(self):
-        if not hasattr(self, 'hashkey'):
-            self.hashkey = mapper_hash_key(
-                self.class_,
-                self.selectable,
-                self.table,
-                self.properties,
-                self.scope,
-                self.echo
-            )
         return self.hashkey
-
+        
     def set_property(self, key, prop):
         self.props[key] = prop
         prop.init(key, self)
 
     def init(self):
         [prop.init(key, self) for key, prop in self.props.iteritems()]
-        self.class_._mapper = self.hash_key()
+        self.class_._mapper = self.hashkey
 
     def instances(self, cursor):
         result = util.HistoryArraySet()
@@ -280,7 +272,36 @@ class Mapper(object):
             for prop in self.props.values():
                 prop.save(obj, traverse)
 
+    def save_obj(self, obj):
+        for table in self.tables:
+            params = {}
+            for primary_key in table.primary_keys:
+                if self._getattrbycolumn(obj, primary_key) is None:
+                    statement = table.insert()
+                    for col in table.columns:
+                        params[col.key] = self._getattrbycolumn(obj, col)
+                    break
+            else:
+                clause = sql.and_()
+                for col in table.columns:
+                    if col.primary_key:
+                        clause.clauses.append(col == self._getattrbycolumn(obj, col))
+                    else:
+                        params[col.key] = self._getattrbycolumn(obj, col)
+                statement = table.update(clause)
+            statement.echo = self.echo
+            statement.execute(**params)
+            if isinstance(statement, sql.Insert):
+                primary_keys = table.engine.last_inserted_ids()
+                index = 0
+                for col in table.primary_keys:
+                    newid = primary_keys[index]
+                    index += 1
+                    self._setattrbycolumn(obj, col, newid)
+                self.put(obj)
+
     def register_dependencies(self, obj, uow):
+        print "hi1"
         for prop in self.props.values():
             prop.register_dependencies(obj, uow)
             
@@ -322,7 +343,7 @@ class Mapper(object):
         exists = objectstore.has_key(identitykey)
         if not exists:
             instance = self.class_()
-            instance._mapper = self.hash_key()
+            instance._mapper = self.hashkey
             for column in self.selectable.primary_keys:
                 if row[column.label] is None:
                     return None
@@ -413,17 +434,15 @@ class ColumnProperty(MapperProperty):
 class PropertyLoader(MapperProperty):
     """describes an object property that holds a list of items that correspond to a related
     database table."""
-    def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, uselist = True, direction = None):
+    def __init__(self, mapper, secondary, primaryjoin, secondaryjoin, uselist = True, foreignkey = None):
         self.uselist = uselist
         self.mapper = mapper
         self.target = self.mapper.selectable
         self.secondary = secondary
         self.primaryjoin = primaryjoin
         self.secondaryjoin = secondaryjoin
-        self.direction = direction
-        self._hash_key = "%s(%s, %s, %s, %s, uselist=%s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), repr(self.uselist))
-        if self.direction is not None and self.direction != 'left' and self.direction != 'right':
-            raise "direction propery must be 'left', 'right' or None"
+        self.foreignkey = foreignkey
+        self._hash_key = "%s(%s, %s, %s, %s, %s, uselist=%s)" % (self.__class__.__name__, hash_key(mapper), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(self.uselist))
             
     def hash_key(self):
         return self._hash_key
@@ -449,7 +468,7 @@ class PropertyLoader(MapperProperty):
             # else we usually will have a one-to-many where the secondary depends on the primary
             # but its possible that its reversed
             w = PropertyLoader.FindDependent()
-            w.accept_visitor(self.primaryjoin)
+            self.primaryjoin.accept_visitor(w)
             if w.dependent is None:
                 raise "cant determine primary foreign key in the join relationship....specify foreignkey=<column>"
             else:
@@ -459,14 +478,16 @@ class PropertyLoader(MapperProperty):
             setattr(parent.class_, key, SmartProperty(key).property(usehistory = True, uselist = self.uselist))
 
     class FindDependent(sql.ClauseVisitor):
+        def __init__(self):
+            self.dependent = None
         def visit_binary(self, binary):
             if binary.operator == '=':
                 if binary.left.primary_key:
-                    if self.dependent == binary.left:
+                    if self.dependent is binary.left:
                         raise "bidirectional dependency not supported...specify foreignkey"
                     self.dependent = binary.right
                 elif binary.right.primary_key:
-                    if self.dependent == binary.right:
+                    if self.dependent is binary.right:
                         raise "bidirectional dependency not supported...specify foreignkey"
                     self.dependent = binary.left
                 
@@ -479,21 +500,43 @@ class PropertyLoader(MapperProperty):
             return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys])
 
     def register_dependencies(self, obj, uow):
+        print 'hi2'
         if self.uselist:
             childlist = objectstore.uow().register_list_attribute(obj, self.key)
         else:
             childlist = objectstore.uow().register_attribute(obj, self.key)
 
-        if self.secondarytable is not None:
+        if self.secondaryjoin is not None:
+            print "hi6?"
             # TODO: put a "row" as a dependency into the UOW somehow
             pass
         elif self.foreignkey.table == self.target:
+            print "hi4"
+            setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, obj)
+            def foo(obj, child):
+                setter.obj = obj
+                setter.child = child
+                setter.associationrow = {}
+                self.primaryjoin.accept_visitor(setter)
+                
             for child in childlist.added_items():
-                uow.register_dependency(obj, child)
-        elif self.foreignkey.table == self.secondary:
+                uow.register_dependency(obj, child, foo)
+        elif self.foreignkey.table == self.parent.table:
+            print "hi5"
+            setter = ForeignKeySetter(self.parent, self.mapper, self.parent.table, self.target, self.secondary, None)
+
+         #   setter = ForeignKeySetter(self.mapper, self.parent, self.target, self.parent.table, self.secondary, obj)
+            def foo(obj, child):
+                print "hi7"
+                setter.obj = obj
+                setter.child = child
+                setter.associationrow = {}
+                self.primaryjoin.accept_visitor(setter)
+
             for child in childlist.added_items():
-                uow.register_dependency(child, obj)
-        
+                uow.register_dependency(child, obj, foo)
+        else:
+            raise " no foreign key ?"
 
     def save(self, obj, traverse):
         # saves child objects