]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Databases which rely upon postfetch of "last inserted id" to get at a
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Aug 2009 00:36:00 +0000 (00:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Aug 2009 00:36:00 +0000 (00:36 +0000)
generated sequence value (i.e. MySQL, MS-SQL) now work correctly
when there is a composite primary key where the "autoincrement" column
is not the first primary key column in the table.

06CHANGES
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/schema.py
test/sql/test_query.py

index 141f834a26e48960c334f9e4afb3325cd2dec150..bc2eb56f655e38a2663eac6d8517c82457ab6cd8 100644 (file)
--- a/06CHANGES
+++ b/06CHANGES
       (a version number check is performed).   This occurs if no end-user
       returning() was specified.
       
+    - Databases which rely upon postfetch of "last inserted id" to get at a 
+      generated sequence value (i.e. MySQL, MS-SQL) now work correctly
+      when there is a composite primary key where the "autoincrement" column
+      is not the first primary key column in the table.
       
 - engines
     - transaction isolation level may be specified with
index f21f53fd22ff81405918d72a0cb5b80abe63f109..a521932970c954bf06268bd422fae9fa97c54142 100644 (file)
@@ -823,42 +823,22 @@ class MSTypeCompiler(compiler.GenericTypeCompiler):
     def visit_SQL_VARIANT(self, type_):
         return 'SQL_VARIANT'
 
-def _has_implicit_sequence(column):
-    return column.primary_key and  \
-        column.autoincrement and \
-        isinstance(column.type, sqltypes.Integer) and \
-        not column.foreign_keys and \
-        (
-            column.default is None or 
-            (
-                isinstance(column.default, sa_schema.Sequence) and 
-                column.default.optional)
-            )
-
-def _table_sequence_column(tbl):
-    if not hasattr(tbl, '_ms_has_sequence'):
-        tbl._ms_has_sequence = None
-        for column in tbl.c:
-            if getattr(column, 'sequence', False) or _has_implicit_sequence(column):
-                tbl._ms_has_sequence = column
-                break
-    return tbl._ms_has_sequence
-
 class MSExecutionContext(default.DefaultExecutionContext):
     _enable_identity_insert = False
     _select_lastrowid = False
     _result_proxy = None
+    _lastrowid = None
     
     def pre_exec(self):
         """Activate IDENTITY_INSERT if needed."""
 
         if self.isinsert:
             tbl = self.compiled.statement.table
-            seq_column = _table_sequence_column(tbl)
-            insert_has_sequence = bool(seq_column)
+            seq_column = tbl._autoincrement_column
+            insert_has_sequence = seq_column is not None
             
             if insert_has_sequence:
-                self._enable_identity_insert = tbl._ms_has_sequence.key in self.compiled_parameters[0]
+                self._enable_identity_insert = seq_column.key in self.compiled_parameters[0]
             else:
                 self._enable_identity_insert = False
             
@@ -1094,7 +1074,7 @@ class MSDDLCompiler(compiler.DDLCompiler):
         if not column.table:
             raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
             
-        seq_col = _table_sequence_column(column.table)
+        seq_col = column.table._autoincrement_column
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if seq_col is column:
@@ -1147,7 +1127,8 @@ class MSDialect(default.DefaultDialect):
     preexecute_pk_sequences = True
     
     supports_unicode_binds = True
-
+    postfetch_lastrowid = True
+    
     server_version_info = ()
     
     statement_compiler = MSSQLCompiler
index 14e73e58862eea702b28ae55354a178595c1304c..45918618d4fb6a8e2bdc717c130e19935fb43eda 100644 (file)
@@ -182,7 +182,6 @@ class DefaultDialect(base.Dialect):
 
 
 class DefaultExecutionContext(base.ExecutionContext):
-    _lastrowid = None
     
     def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None):
         self.dialect = dialect
@@ -385,12 +384,15 @@ class DefaultExecutionContext(base.ExecutionContext):
     
     def post_insert(self):
         if self.dialect.postfetch_lastrowid and \
-            self._lastrowid is None and \
             (not len(self._last_inserted_ids) or \
-                        self._last_inserted_ids[0] is None):
+                        None in self._last_inserted_ids):
+
+            table = self.compiled.statement.table
+            lastrowid = self.get_lastrowid()
+            self._last_inserted_ids = [c is table._autoincrement_column and lastrowid or v
+                for c, v in zip(table.primary_key, self._last_inserted_ids)
+            ]
             
-            self._lastrowid = self.get_lastrowid()
-        
     def last_inserted_ids(self, resultproxy):
         if not self.isinsert:
             raise exc.InvalidRequestError("Statement is not an insert() expression construct.")
@@ -398,16 +400,15 @@ class DefaultExecutionContext(base.ExecutionContext):
         if self.dialect.implicit_returning and \
                 not self.compiled.statement._returning and \
                 not resultproxy.closed:
-
+            
+            table = self.compiled.statement.table
             row = resultproxy.first()
 
             self._last_inserted_ids = [v is not None and v or row[c] 
-                for c, v in zip(self.compiled.statement.table.primary_key, self._last_inserted_ids)
+                for c, v in zip(table.primary_key, self._last_inserted_ids)
             ]
             return self._last_inserted_ids
             
-        elif self._lastrowid is not None:
-            return [self._lastrowid] + self._last_inserted_ids[1:]
         else:
             return self._last_inserted_ids
 
@@ -497,7 +498,8 @@ class DefaultExecutionContext(base.ExecutionContext):
                     compiled_parameters[c.key] = val
 
             if self.isinsert:
-                self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key]
+                self._last_inserted_ids = [compiled_parameters.get(c.key, None) 
+                                            for c in self.compiled.statement.table.primary_key]
                 self._last_inserted_params = compiled_parameters
             else:
                 self._last_updated_params = compiled_parameters
index 346bf884af9d388ba0c2a182815b132899ba6062..a6961aab50a248d6b680f4ff150a6c12e3a51330 100644 (file)
@@ -280,6 +280,15 @@ class Table(SchemaItem, expression.TableClause):
         for c in pk.columns:
             c.primary_key = True
 
+    @util.memoized_property
+    def _autoincrement_column(self):
+        for col in self.primary_key:
+            if col.autoincrement and \
+                isinstance(col.type, types.Integer) and \
+                not col.foreign_keys:
+
+                return col
+
     @property
     def key(self):
         return _get_table_key(self.name, self.schema)
index 37030c94f4577ba099fec907429da10e6d937b67..979c148e4ae8d0c129f16e9887912d7452fca4a0 100644 (file)
@@ -80,7 +80,7 @@ class QueryTest(TestBase):
                     ret[c.key] = row[c]
             return ret
 
-        if testing.against('firebird', 'postgres', 'oracle', 'mssql'):
+        if testing.against('firebird', 'postgres', 'oracle'): #, 'mssql'):
             test_engines = [
                 engines.testing_engine(options={'implicit_returning':False}),
                 engines.testing_engine(options={'implicit_returning':True}),
@@ -148,6 +148,25 @@ class QueryTest(TestBase):
                 finally:
                     table.drop(bind=engine)
 
+    @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks")
+    def test_misordered_lastrow(self):
+        related = Table('related', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+        t6 = Table("t6", metadata,
+            Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True),
+            Column('auto_id', Integer, primary_key=True),
+        )
+
+        metadata.create_all()
+        r = related.insert().values(id=12).execute()
+        id = r.last_inserted_ids()[0]
+        assert id==12
+
+        r = t6.insert().values(manual_id=id).execute()
+        eq_(r.last_inserted_ids(), [12, 1])
+
+
     def test_row_iteration(self):
         users.insert().execute(
             {'user_id':7, 'user_name':'jack'},