]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Sep 2005 06:59:30 +0000 (06:59 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Sep 2005 06:59:30 +0000 (06:59 +0000)
lib/sqlalchemy/mapper.py

index 141a990694a52880fac158dcc55e9ecb69d1454d..09e6b09623be6150aa538315e50b39f0ae2cf2ca 100644 (file)
@@ -224,14 +224,10 @@ class Mapper(object):
         for table in self.tables:
             params = {}
                 
-            needs_primaries = False
-            for primary_key in table.primary_keys:
-                if self._getattrbycolumn(obj, primary_key) is None:
-                    needs_primaries = True
-                if isinsert:
-                    for col in table.columns:
-                        params[col.key] = self._getattrbycolumn(obj, col)
-                    break
+            if isinsert:
+                for col in table.columns:
+                    params[col.key] = self._getattrbycolumn(obj, col)
+                statement = table.insert()
             else:
                 clause = sql.and_()
                 for col in table.columns:
@@ -239,23 +235,21 @@ class Mapper(object):
                         clause.clauses.append(col == self._getattrbycolumn(obj, col))
                     else:
                         params[col.key] = self._getattrbycolumn(obj, col)
-
-            if not isinsert:
                 statement = table.update(clause)
-            else:
-                statement = table.insert()
 
             statement.echo = self.echo
             statement.execute(**params)
 
-            if needs_primaries and isinstance(statement, sql.Insert):
-                primary_keys = table.engine.last_inserted_ids()
-                index = 0
+            if isinstance(statement, sql.Insert):
+                primary_key = table.engine.last_inserted_ids()[0]
+                found = False
                 for col in table.primary_keys:
-                    newid = primary_keys[index]
-                    index += 1
-                    self._setattrbycolumn(obj, col, newid)
-                #self.put(obj)
+                    if self._getattrbycolumn(obj, col) is None:
+                        if found:
+                            raise "Only one primary key per inserted row can be set via autoincrement/sequence"
+                        else:
+                            self._setattrbycolumn(obj, col, primary_key)
+                            found = True
 
     def register_dependencies(self, obj, uow):
         for prop in self.props.values():
@@ -403,8 +397,6 @@ class PropertyLoader(MapperProperty):
         self.key = key
         self.parent = parent
         
-        # TODO: if just a foreign key specified, figure out the proper "match_primaries" relationship
-        
         # if join conditions were not specified, figure them out based on primary keys
         if self.secondary is not None:
             if self.secondaryjoin is None:
@@ -413,7 +405,10 @@ class PropertyLoader(MapperProperty):
                 self.primaryjoin = self.match_primaries(parent.selectable, self.secondary)
         else:
             if self.primaryjoin is None:
-                self.primaryjoin = self.match_primaries(parent.selectable, self.target)
+                if self.foreignkey is not None and self.foreignkey.table == parent.selectable:
+                    self.primaryjoin = self.match_primaries(self.target, parent.selectable)
+                else:
+                    self.primaryjoin = self.match_primaries(parent.selectable, self.target)
         
         # if the foreign key wasnt specified and theres no assocaition table, try to figure
         # out who is dependent on who. we dont need all the foreign keys represented in the join,
@@ -448,11 +443,14 @@ class PropertyLoader(MapperProperty):
 
     def match_primaries(self, primary, secondary):
         pk = primary.primary_keys
-        if len(pk) == 1:
-            return (pk[0] == secondary.c[pk[0].name])
-        else:
-            return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys])
-
+        try:
+            if len(pk) == 1:
+                return (pk[0] == secondary.c[pk[0].name])
+            else:
+                return sql.and_([pk == secondary.c[pk.name] for pk in primary.primary_keys])
+        except AttributeError, e:
+            raise e.args[0] + " table: " + secondary.name
+            
     def register_dependencies(self, objlist, uow):
         if self.secondaryjoin is not None:
             # with many-to-many, set the parent as dependent on us, then the