]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- result.last_inserted_ids() should return a list that is identically
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jun 2007 00:49:08 +0000 (00:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 Jun 2007 00:49:08 +0000 (00:49 +0000)
sized to the primary key constraint of the table.  values that were
"passively" created and not available via cursor.lastrowid will be None.
- sqlite: string PK column inserts dont get overwritten with OID [ticket:603]

CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/default.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index 1662d7624d20b6e7934bac067995fbf5c97b7c3b..58c6c48335d84b45a99de415d2321bc5da426c59 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -14,6 +14,9 @@
       to polymorphic mappers that are using a straight "outerjoin"
       clause
 - sql
+    - result.last_inserted_ids() should return a list that is identically
+      sized to the primary key constraint of the table.  values that were 
+      "passively" created and not available via cursor.lastrowid will be None.
     - long-identifier detection fixed to use > rather than >= for 
       max ident length [ticket:589]
     - fixed bug where selectable.corresponding_column(selectable.c.col)
@@ -28,9 +31,8 @@
 - sqlite
     - sqlite better handles datetime/date/time objects mixed and matched
       with various Date/Time/DateTime columns
+    - string PK column inserts dont get overwritten with OID [ticket:603] 
 
-    
->>>>>>> .r2741
 0.3.8
 - engines
     - added detach() to Connection, allows underlying DBAPI connection
index 4336296dd90fa520d9615c8b287cbcfdbc7c5415..2b6808eaca4bada54e0df74e3e6d0c4e6ad4b615 100644 (file)
@@ -271,13 +271,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
                 self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname)
                 self.IINSERT = False
             elif self.HASIDENT:
-                if self.dialect.use_scope_identity:
-                    self.cursor.execute("SELECT scope_identity() AS lastrowid")
-                else:
-                    self.cursor.execute("SELECT @@identity AS lastrowid")
-                row = self.cursor.fetchone()
-                self._last_inserted_ids = [int(row[0])]
-                # print "LAST ROW ID", self._last_inserted_ids
+                if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                    if self.dialect.use_scope_identity:
+                        self.cursor.execute("SELECT scope_identity() AS lastrowid")
+                    else:
+                        self.cursor.execute("SELECT @@identity AS lastrowid")
+                    row = self.cursor.fetchone()
+                    self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
+                    # print "LAST ROW ID", self._last_inserted_ids
             self.HASIDENT = False
         super(MSSQLExecutionContext, self).post_exec()
 
index 09825bef0f740b1ff6e4014aad37c84eb2485159..91d59f1e23d89b311fb9328fc5d823b97fc7003e 100644 (file)
@@ -945,8 +945,9 @@ def descriptor():
 class MySQLExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
         if self.compiled.isinsert:
-            self._last_inserted_ids = [self.cursor.lastrowid]
-
+            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+                
 class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
index aaaf55697f87e87e81a0142a1c2fab55f9b6a757..8e776fb26ed7e81bc4fa17b0d5a8dea8030e6dc0 100644 (file)
@@ -148,7 +148,9 @@ def descriptor():
 class SQLiteExecutionContext(default.DefaultExecutionContext):
     def post_exec(self):
         if self.compiled.isinsert:
-            self._last_inserted_ids = [self.cursor.lastrowid]
+            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:]
+                
         super(SQLiteExecutionContext, self).post_exec()
         
 class SQLiteDialect(ansisql.ANSIDialect):
index 976da1a73b869f78a4bb70d3a5b6040de0f9ad7b..f295f5fd24666461e81e5c56b3cd6a32e4507acd 100644 (file)
@@ -283,7 +283,6 @@ class DefaultExecutionContext(base.ExecutionContext):
             self._lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
-                need_lastrowid=False
                 # check the "default" status of each column in the table
                 for c in self.compiled.statement.table.c:
                     # check if it will be populated by a SQL clause - we'll need that
@@ -291,7 +290,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     if c in self.compiled.inline_params:
                         self._lastrow_has_defaults = True
                         if c.primary_key:
-                            need_lastrowid = True
+                            last_inserted_ids.append(None)
                     # check if its not present at all.  see if theres a default
                     # and fire it off, and add to bind parameters.  if
                     # its a pk, add the value to our last_inserted_ids list,
@@ -306,15 +305,14 @@ class DefaultExecutionContext(base.ExecutionContext):
                             if c.primary_key:
                                 last_inserted_ids.append(param.get_processed(c.key))
                         elif c.primary_key:
-                            need_lastrowid = True
+                            last_inserted_ids.append(None)
                     # its an explicitly passed pk value - add it to
                     # our last_inserted_ids list.
                     elif c.primary_key:
                         last_inserted_ids.append(param.get_processed(c.key))
-                if need_lastrowid:
-                    self._last_inserted_ids = None
-                else:
-                    self._last_inserted_ids = last_inserted_ids
+                # TODO: we arent accounting for executemany() situations
+                # here (hard to do since lastrowid doesnt support it either)
+                self._last_inserted_ids = last_inserted_ids
                 self._last_inserted_params = param
         elif self.compiled.isupdate:
             if isinstance(self.compiled_parameters, list):
index 6e43f8779a7295713404cf1957f1fc3e94c9bc29..1c63132b59900cfb98a4370f80023b17a2b271d0 100644 (file)
@@ -44,6 +44,85 @@ class QueryTest(PersistTest):
         self.users.update(self.users.c.user_id == 7).execute(user_name = 'fred')
         print repr(self.users.select().execute().fetchall())
 
+    def test_lastrow_accessor(self):
+        """test the last_inserted_ids() and lastrow_has_id() functions"""
+
+        def insert_values(table, values):
+            """insert a row into a table, return the full list of values INSERTed including defaults
+            that fired off on the DB side.  
+            
+            detects rows that had defaults and post-fetches.
+            """
+            
+            result = table.insert().execute(**values)
+            ret = values.copy()
+
+            for col, id in zip(table.primary_key, result.last_inserted_ids()):
+                ret[col.key] = id
+
+            if result.lastrow_has_defaults():
+                criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())])
+                row = table.select(criterion).execute().fetchone()
+                ret.update(row)
+            return ret
+
+        for supported, table, values, assertvalues in [
+            (
+                {'unsupported':['sqlite']},
+                Table("t1", metadata, 
+                    Column('id', Integer, primary_key=True),
+                    Column('foo', String(30), primary_key=True)),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi'}
+            ),
+            (
+                {'unsupported':['sqlite']},
+                Table("t2", metadata, 
+                    Column('id', Integer, primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column('bar', String(30), PassiveDefault('hi'))
+                ),
+                {'foo':'hi'},
+                {'id':1, 'foo':'hi', 'bar':'hi'}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t3", metadata, 
+                    Column("id", String(40), primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column("bar", String(30))
+                    ),
+                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"},
+                    {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t4", metadata, 
+                    Column('id', Integer, primary_key=True),
+                    Column('foo', String(30), primary_key=True),
+                    Column('bar', String(30), PassiveDefault('hi'))
+                ),
+                {'foo':'hi', 'id':1},
+                {'id':1, 'foo':'hi', 'bar':'hi'}
+            ),
+            (
+                {'unsupported':[]},
+                Table("t5", metadata, 
+                    Column('id', String(10), primary_key=True),
+                    Column('bar', String(30), PassiveDefault('hi'))
+                ),
+                {'id':'id1'},
+                {'id':'id1', 'bar':'hi'},
+            ),
+        ]:
+            if testbase.db.name in supported['unsupported']:
+                continue
+            try:
+                table.create()
+                assert insert_values(table, values) == assertvalues, repr(values) + " " + repr(assertvalues)
+            finally:
+                table.drop()
+
     def testrowiteration(self):
         self.users.insert().execute(user_id = 7, user_name = 'jack')
         self.users.insert().execute(user_id = 8, user_name = 'ed')