]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 18:47:54 +0000 (18:47 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Oct 2005 18:47:54 +0000 (18:47 +0000)
examples/adjacencytree/tables.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py

index ac8be6e58639fe5fed93b6000e3017b7f4aa102c..f39062892b18684b280db0548ca740c7cab9851d 100644 (file)
@@ -3,8 +3,8 @@ import sqlalchemy.engine
 import os
 
 #engine = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = True)
-#engine = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=True)
-engine = sqlalchemy.engine.create_engine('oracle', {'dsn':os.environ['DSN'], 'user':os.environ['USER'], 'password':os.environ['PASSWORD']}, echo=True)
+engine = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=True)
+#engine = sqlalchemy.engine.create_engine('oracle', {'dsn':os.environ['DSN'], 'user':os.environ['USER'], 'password':os.environ['PASSWORD']}, echo=True)
 
 
 """create the treenodes table.  This is ia basic adjacency list model table.
index 15b6ce6682ec042341b2eee81ef6d9c599d0b5c4..bb1f3c6aea407f496fc3ddce03a703c4e44707c8 100644 (file)
@@ -105,10 +105,9 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
         raise "not implemented"
 
     def last_inserted_ids(self):
-       return self.context.last_inserted_ids
+        return self.context.last_inserted_ids
 
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-        # if a sequence was explicitly defined we do it here
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
             if isinstance(parameters, list):
@@ -126,13 +125,11 @@ class OracleSQLEngine(ansisql.ANSISQLEngine):
                         cursor.execute("select %s.nextval from dual" % primary_key.sequence.name)
                         newid = cursor.fetchone()[0]
                         param[primary_key.key] = newid
-                        #if compiled.statement.parameters is not None:
-                         #   compiled.statement.parameters[primary_key.key] = bindparam(primary_key.key)
                     last_inserted_ids.append(param[primary_key.key])
                 self.context.last_inserted_ids = last_inserted_ids
 
     def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
-       pass
+        pass
 
     def _executemany(self, c, statement, parameters):
         rowcount = 0
@@ -178,11 +175,13 @@ class OracleCompiler(ansisql.ANSICompiler):
             self.strings[column] = "%s.%s" % (column.table.name, column.name)
         
     def visit_insert(self, insert):
+        """inserts are required to have the primary keys be explicitly present.
+         mapper will by default not put them in the insert statement to comply
+         with autoincrement fields that require they not be present.  so, 
+         put them all in for all primary key columns."""
         for c in insert.table.primary_keys:
             if not self.bindparams.has_key(c.key):
                 self.bindparams[c.key] = None
-                #if not insert.parameters.has_key(c.key):
-                 #   insert.parameters[c.key] = sql.bindparam(c.key)
         return ansisql.ANSICompiler.visit_insert(self, insert)
 
 class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
index 76d5248f771908b08a552558c5b53d7315d1fbab..86116f83fb918d93ab4f1f4af5f0c3dd9f14607c 100644 (file)
@@ -103,6 +103,12 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         raise "not implemented"
         
     def last_inserted_ids(self):
+        # if we used sequences or already had all values for the last inserted row,
+        # return that list
+        if self.context.last_inserted_ids is not None:
+            return self.context.last_inserted_ids
+        
+        # else we have to use lastrowid and select the most recently inserted row    
         table = self.context.last_inserted_table
         if self.context.lastrowid is not None and table is not None and len(table.primary_keys):
             row = sql.select(table.primary_keys, table.rowid_column == self.context.lastrowid).execute().fetchone()
@@ -114,13 +120,34 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
         # if a sequence was explicitly defined we do it here
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
-            for primary_key in compiled.statement.table.primary_keys:
-                if primary_key.sequence is not None and not primary_key.sequence.optional and parameters[primary_key.key] is None:
-                    if echo is True or self.echo:
-                        self.log("select nextval('%s')" % primary_key.sequence.name)
-                    cursor.execute("select nextval('%s')" % primary_key.sequence.name)
-                    newid = cursor.fetchone()[0]
-                    parameters[primary_key.key] = newid
+            if isinstance(parameters, list):
+                plist = parameters
+            else:
+                plist = [parameters]
+            # inserts are usually one at a time.  but if we got a list of parameters,
+            # it will calculate last_inserted_ids for just the last row in the list. 
+            # TODO: why not make last_inserted_ids a 2D array since we have to explicitly sequence
+            # it or post-select anyway   
+            for param in plist:
+                last_inserted_ids = []
+                need_lastrowid=False
+                for primary_key in compiled.statement.table.primary_keys:
+                    if not param.has_key(primary_key.key) or param[primary_key.key] is None:
+                        if primary_key.sequence is not None and not primary_key.sequence.optional:
+                            if echo is True or self.echo:
+                                self.log("select nextval('%s')" % primary_key.sequence.name)
+                            cursor.execute("select nextval('%s')" % primary_key.sequence.name)
+                            newid = cursor.fetchone()[0]
+                            param[primary_key.key] = newid
+                            last_inserted_ids.append(param[primary_key.key])
+                        else:
+                            need_lastrowid = True
+                    else:
+                        last_inserted_ids.append(param[primary_key.key])
+                if need_lastrowid:
+                    self.context.last_inserted_ids = None
+                else:
+                    self.context.last_inserted_ids = last_inserted_ids
 
     def _executemany(self, c, statement, parameters):
         """we need accurate rowcounts for updates, inserts and deletes.  psycopg2 is not nice enough
@@ -149,11 +176,13 @@ class PGCompiler(ansisql.ANSICompiler):
         return "%(" + name + ")s"
 
     def visit_insert(self, insert):
+        """inserts are required to have the primary keys be explicitly present.
+         mapper will by default not put them in the insert statement to comply
+         with autoincrement fields that require they not be present.  so, 
+         put them all in for columns where sequence usage is defined."""
         for c in insert.table.primary_keys:
             if c.sequence is not None and not c.sequence.optional:
                 self.bindparams[c.key] = None
-                #if not insert.parameters.has_key(c.key):
-                 #   insert.parameters[c.key] = sql.bindparam(c.key)
         return ansisql.ANSICompiler.visit_insert(self, insert)
         
 class PGSchemaGenerator(ansisql.ANSISchemaGenerator):