]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed INSERT statements w.r.t. primary key columns that have SQL-expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Oct 2007 21:28:53 +0000 (21:28 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Oct 2007 21:28:53 +0000 (21:28 +0000)
  based default generators on them; SQL expression executes inline as normal
  but will not trigger a "postfetch" condition for the column, for those DB's
  who provide it via cursor.lastrowid

CHANGES
lib/sqlalchemy/sql/compiler.py
test/sql/defaults.py

diff --git a/CHANGES b/CHANGES
index 718e7528625918fe2a5b91a1df3e641873d1de98..576aecf5d6e717885f70abf386bd5c00b70f8fdc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -20,6 +20,11 @@ CHANGES
   
 - sqlite will reflect "DECIMAL" as a numeric column.
 
+- fixed INSERT statements w.r.t. primary key columns that have SQL-expression
+  based default generators on them; SQL expression executes inline as normal
+  but will not trigger a "postfetch" condition for the column, for those DB's
+  who provide it via cursor.lastrowid
+  
 - Renamed the Dialect attribute 'preexecute_sequences' to
   'preexecute_pk_sequences'.  An attribute proxy is in place for out-of-tree
   dialects using the old name.
index e662f8e99d8f0b03a196fbb69d331c48fddabc80..43b6fb15b5cb12809accb14c6d302847db72d496 100644 (file)
@@ -664,8 +664,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
                 values.append((c, value))
             elif isinstance(c, schema.Column):
                 if self.isinsert:
-                    if (c.primary_key and self.dialect.preexecute_pk_sequences
-                        and not self.inline):
+                    if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline):
                         if (((isinstance(c.default, schema.Sequence) and
                               not c.default.optional) or
                              not self.dialect.supports_pk_autoincrement) or
@@ -676,17 +675,21 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor):
                     elif isinstance(c.default, schema.ColumnDefault):
                         if isinstance(c.default.arg, sql.ClauseElement):
                             values.append((c, self.process(c.default.arg.self_group())))
-                            self.postfetch.add(c)
+                            if not c.primary_key:
+                                # dont add primary key column to postfetch
+                                self.postfetch.add(c)
                         else:
                             values.append((c, create_bind_param(c, None)))
                             self.prefetch.add(c)
                     elif isinstance(c.default, schema.PassiveDefault):
-                        self.postfetch.add(c)
+                        if not c.primary_key:
+                            self.postfetch.add(c)
                     elif isinstance(c.default, schema.Sequence):
                         proc = self.process(c.default)
                         if proc is not None:
                             values.append((c, proc))
-                            self.postfetch.add(c)
+                            if not c.primary_key:
+                                self.postfetch.add(c)
                 elif self.isupdate:
                     if isinstance(c.onupdate, schema.ColumnDefault):
                         if isinstance(c.onupdate.arg, sql.ClauseElement):
index c29ffa3b396d416ffef78be702928afac292e0ad..da23a67a21f54086bfbc3f5a966184ad6ef0a170 100644 (file)
@@ -243,6 +243,34 @@ class DefaultTest(PersistTest):
         finally:
             testbase.db.execute("drop table speedy_users", None)
 
+class PKDefaultTest(PersistTest):
+    def setUpAll(self):
+        global metadata, t1, t2
+        
+        metadata = MetaData(testbase.db)
+        
+        t2 = Table('t2', metadata, 
+            Column('nextid', Integer))
+            
+        t1 = Table('t1', metadata,
+            Column('id', Integer, primary_key=True, default=select([func.max(t2.c.nextid)]).as_scalar()),
+            Column('data', String(30)))
+    
+        metadata.create_all()
+    
+    def tearDownAll(self):
+        metadata.drop_all()
+        
+    def test_basic(self):
+        t2.insert().execute(nextid=1)
+        r = t1.insert().execute(data='hi')
+        assert r.last_inserted_ids() == [1]
+
+        t2.insert().execute(nextid=2)
+        r = t1.insert().execute(data='there')
+        assert r.last_inserted_ids() == [2]
+        
+        
 class AutoIncrementTest(PersistTest):
     def setUp(self):
         global aitable, aimeta