From c9fdf9a455445643a696c241dcb92c6e058480ad Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 1 Aug 2009 00:36:00 +0000 Subject: [PATCH] - Databases which rely upon postfetch of "last inserted id" to get at a generated sequence value (i.e. MySQL, MS-SQL) now work correctly when there is a composite primary key where the "autoincrement" column is not the first primary key column in the table. --- 06CHANGES | 4 ++++ lib/sqlalchemy/dialects/mssql/base.py | 33 ++++++--------------------- lib/sqlalchemy/engine/default.py | 22 ++++++++++-------- lib/sqlalchemy/schema.py | 9 ++++++++ test/sql/test_query.py | 21 ++++++++++++++++- 5 files changed, 52 insertions(+), 37 deletions(-) diff --git a/06CHANGES b/06CHANGES index 141f834a26..bc2eb56f65 100644 --- a/06CHANGES +++ b/06CHANGES @@ -21,6 +21,10 @@ (a version number check is performed). This occurs if no end-user returning() was specified. + - Databases which rely upon postfetch of "last inserted id" to get at a + generated sequence value (i.e. MySQL, MS-SQL) now work correctly + when there is a composite primary key where the "autoincrement" column + is not the first primary key column in the table. - engines - transaction isolation level may be specified with diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index f21f53fd22..a521932970 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -823,42 +823,22 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): def visit_SQL_VARIANT(self, type_): return 'SQL_VARIANT' -def _has_implicit_sequence(column): - return column.primary_key and \ - column.autoincrement and \ - isinstance(column.type, sqltypes.Integer) and \ - not column.foreign_keys and \ - ( - column.default is None or - ( - isinstance(column.default, sa_schema.Sequence) and - column.default.optional) - ) - -def _table_sequence_column(tbl): - if not hasattr(tbl, '_ms_has_sequence'): - tbl._ms_has_sequence = None - for column in tbl.c: - if getattr(column, 'sequence', False) or _has_implicit_sequence(column): - tbl._ms_has_sequence = column - break - return tbl._ms_has_sequence - class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False _result_proxy = None + _lastrowid = None def pre_exec(self): """Activate IDENTITY_INSERT if needed.""" if self.isinsert: tbl = self.compiled.statement.table - seq_column = _table_sequence_column(tbl) - insert_has_sequence = bool(seq_column) + seq_column = tbl._autoincrement_column + insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = tbl._ms_has_sequence.key in self.compiled_parameters[0] + self._enable_identity_insert = seq_column.key in self.compiled_parameters[0] else: self._enable_identity_insert = False @@ -1094,7 +1074,7 @@ class MSDDLCompiler(compiler.DDLCompiler): if not column.table: raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") - seq_col = _table_sequence_column(column.table) + seq_col = column.table._autoincrement_column # install a IDENTITY Sequence if we have an implicit IDENTITY column if seq_col is column: @@ -1147,7 +1127,8 @@ class MSDialect(default.DefaultDialect): preexecute_pk_sequences = True supports_unicode_binds = True - + postfetch_lastrowid = True + server_version_info = () statement_compiler = MSSQLCompiler diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 14e73e5886..45918618d4 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -182,7 +182,6 @@ class DefaultDialect(base.Dialect): class DefaultExecutionContext(base.ExecutionContext): - _lastrowid = None def __init__(self, dialect, connection, compiled_sql=None, compiled_ddl=None, statement=None, parameters=None): self.dialect = dialect @@ -385,12 +384,15 @@ class DefaultExecutionContext(base.ExecutionContext): def post_insert(self): if self.dialect.postfetch_lastrowid and \ - self._lastrowid is None and \ (not len(self._last_inserted_ids) or \ - self._last_inserted_ids[0] is None): + None in self._last_inserted_ids): + + table = self.compiled.statement.table + lastrowid = self.get_lastrowid() + self._last_inserted_ids = [c is table._autoincrement_column and lastrowid or v + for c, v in zip(table.primary_key, self._last_inserted_ids) + ] - self._lastrowid = self.get_lastrowid() - def last_inserted_ids(self, resultproxy): if not self.isinsert: raise exc.InvalidRequestError("Statement is not an insert() expression construct.") @@ -398,16 +400,15 @@ class DefaultExecutionContext(base.ExecutionContext): if self.dialect.implicit_returning and \ not self.compiled.statement._returning and \ not resultproxy.closed: - + + table = self.compiled.statement.table row = resultproxy.first() self._last_inserted_ids = [v is not None and v or row[c] - for c, v in zip(self.compiled.statement.table.primary_key, self._last_inserted_ids) + for c, v in zip(table.primary_key, self._last_inserted_ids) ] return self._last_inserted_ids - elif self._lastrowid is not None: - return [self._lastrowid] + self._last_inserted_ids[1:] else: return self._last_inserted_ids @@ -497,7 +498,8 @@ class DefaultExecutionContext(base.ExecutionContext): compiled_parameters[c.key] = val if self.isinsert: - self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] + self._last_inserted_ids = [compiled_parameters.get(c.key, None) + for c in self.compiled.statement.table.primary_key] self._last_inserted_params = compiled_parameters else: self._last_updated_params = compiled_parameters diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 346bf884af..a6961aab50 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -280,6 +280,15 @@ class Table(SchemaItem, expression.TableClause): for c in pk.columns: c.primary_key = True + @util.memoized_property + def _autoincrement_column(self): + for col in self.primary_key: + if col.autoincrement and \ + isinstance(col.type, types.Integer) and \ + not col.foreign_keys: + + return col + @property def key(self): return _get_table_key(self.name, self.schema) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 37030c94f4..979c148e4a 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -80,7 +80,7 @@ class QueryTest(TestBase): ret[c.key] = row[c] return ret - if testing.against('firebird', 'postgres', 'oracle', 'mssql'): + if testing.against('firebird', 'postgres', 'oracle'): #, 'mssql'): test_engines = [ engines.testing_engine(options={'implicit_returning':False}), engines.testing_engine(options={'implicit_returning':True}), @@ -148,6 +148,25 @@ class QueryTest(TestBase): finally: table.drop(bind=engine) + @testing.fails_on('sqlite', "sqlite autoincremnt doesn't work with composite pks") + def test_misordered_lastrow(self): + related = Table('related', metadata, + Column('id', Integer, primary_key=True) + ) + t6 = Table("t6", metadata, + Column('manual_id', Integer, ForeignKey('related.id'), primary_key=True), + Column('auto_id', Integer, primary_key=True), + ) + + metadata.create_all() + r = related.insert().values(id=12).execute() + id = r.last_inserted_ids()[0] + assert id==12 + + r = t6.insert().values(manual_id=id).execute() + eq_(r.last_inserted_ids(), [12, 1]) + + def test_row_iteration(self): users.insert().execute( {'user_id':7, 'user_name':'jack'}, -- 2.47.3