From: Mike Bayer Date: Sun, 17 Jun 2007 00:57:06 +0000 (+0000) Subject: - merged last_inserted_ids() fix from trunk [changeset:2743] X-Git-Tag: rel_0_4_6~193 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=37762aad5c3d4e43bbf01fbd798d6d010672fbdb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - merged last_inserted_ids() fix from trunk [changeset:2743] --- diff --git a/CHANGES b/CHANGES index e65b09f68b..b7e3036a2b 100644 --- a/CHANGES +++ b/CHANGES @@ -44,9 +44,25 @@ - added "explcit" create/drop/execute support for sequences (i.e. you can pass a "connectable" to each of those methods on Sequence) - + - result.last_inserted_ids() should return a list that is identically + sized to the primary key constraint of the table. values that were + "passively" created and not available via cursor.lastrowid will be None. + - long-identifier detection fixed to use > rather than >= for + max ident length [ticket:589] + - fixed bug where selectable.corresponding_column(selectable.c.col) + would not return selectable.c.col, if the selectable is a join + of a table and another join involving the same table. messed + up ORM decision making [ticket:593] - mysql - added 'fields' to reserved words [ticket:590] + +- oracle + - datetime fixes: got subsecond TIMESTAMP to work [ticket:604], + added OracleDate which supports types.Date with only year/month/day +- sqlite + - sqlite better handles datetime/date/time objects mixed and matched + with various Date/Time/DateTime columns + - string PK column inserts dont get overwritten with OID [ticket:603] - extensions - proxyengine is temporarily removed, pending an actually working replacement. @@ -54,7 +70,7 @@ SelectResultsExt still exist but just return a slightly modified Query object for backwards-compatibility. join_to() method from SelectResults isn't present anymore, need to use join(). - + 0.3.8 - engines - added detach() to Connection, allows underlying DBAPI connection diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 4336296dd9..2b6808eaca 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -271,13 +271,14 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.compiled.statement.table.fullname) self.IINSERT = False elif self.HASIDENT: - if self.dialect.use_scope_identity: - self.cursor.execute("SELECT scope_identity() AS lastrowid") - else: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] - # print "LAST ROW ID", self._last_inserted_ids + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids self.HASIDENT = False super(MSSQLExecutionContext, self).post_exec() diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 63ce05eb68..e45536a756 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -945,7 +945,8 @@ def descriptor(): class MySQLExecutionContext(default.DefaultExecutionContext): def post_exec(self): if self.compiled.isinsert: - self._last_inserted_ids = [self.cursor.lastrowid] + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] def is_select(self): return re.match(r'SELECT|SHOW|DESCRIBE', self.statement.lstrip(), re.I) is not None diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 53525582f6..15a30bafd1 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -139,8 +139,8 @@ def descriptor(): class SQLiteExecutionContext(default.DefaultExecutionContext): def post_exec(self): if self.compiled.isinsert: - self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] - super(SQLiteExecutionContext, self).post_exec() + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + self._last_inserted_ids = [self.cursor.lastrowid] + self._last_inserted_ids[1:] def is_select(self): return re.match(r'SELECT|PRAGMA', self.statement.lstrip(), re.I) is not None diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index eca0faf914..64c5de2b0a 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -285,7 +285,6 @@ class DefaultExecutionContext(base.ExecutionContext): self._lastrow_has_defaults = False for param in plist: last_inserted_ids = [] - need_lastrowid=False # check the "default" status of each column in the table for c in self.compiled.statement.table.c: # check if it will be populated by a SQL clause - we'll need that @@ -313,6 +312,8 @@ class DefaultExecutionContext(base.ExecutionContext): # our last_inserted_ids list. elif c.primary_key: last_inserted_ids.append(param.get_processed(c.key)) + # TODO: we arent accounting for executemany() situations + # here (hard to do since lastrowid doesnt support it either) self._last_inserted_ids = last_inserted_ids self._last_inserted_params = param elif self.compiled.isupdate: diff --git a/test/sql/query.py b/test/sql/query.py index 8597fbe756..a47fcdef10 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -36,25 +36,39 @@ class QueryTest(PersistTest): users.insert().execute(user_id = 7, user_name = 'jack') assert users.count().scalar() == 1 - @testbase.unsupported('sqlite') + def testupdate(self): + + users.insert().execute(user_id = 7, user_name = 'jack') + assert users.count().scalar() == 1 + + users.update(users.c.user_id == 7).execute(user_name = 'fred') + assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' + def test_lastrow_accessor(self): """test the last_inserted_ids() and lastrow_has_id() functions""" - + def insert_values(table, values): + """insert a row into a table, return the full list of values INSERTed including defaults + that fired off on the DB side. + + detects rows that had defaults and post-fetches. + """ + result = table.insert().execute(**values) ret = values.copy() - + for col, id in zip(table.primary_key, result.last_inserted_ids()): ret[col.key] = id - + if result.lastrow_has_defaults(): criterion = and_(*[col==id for col, id in zip(table.primary_key, result.last_inserted_ids())]) row = table.select(criterion).execute().fetchone() ret.update(row) return ret - - for table, values, assertvalues in [ + + for supported, table, values, assertvalues in [ ( + {'unsupported':['sqlite']}, Table("t1", metadata, Column('id', Integer, primary_key=True), Column('foo', String(30), primary_key=True)), @@ -62,6 +76,7 @@ class QueryTest(PersistTest): {'id':1, 'foo':'hi'} ), ( + {'unsupported':['sqlite']}, Table("t2", metadata, Column('id', Integer, primary_key=True), Column('foo', String(30), primary_key=True), @@ -70,21 +85,43 @@ class QueryTest(PersistTest): {'foo':'hi'}, {'id':1, 'foo':'hi', 'bar':'hi'} ), - + ( + {'unsupported':[]}, + Table("t3", metadata, + Column("id", String(40), primary_key=True), + Column('foo', String(30), primary_key=True), + Column("bar", String(30)) + ), + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"}, + {'id':'hi', 'foo':'thisisfoo', 'bar':"thisisbar"} + ), + ( + {'unsupported':[]}, + Table("t4", metadata, + Column('id', Integer, primary_key=True), + Column('foo', String(30), primary_key=True), + Column('bar', String(30), PassiveDefault('hi')) + ), + {'foo':'hi', 'id':1}, + {'id':1, 'foo':'hi', 'bar':'hi'} + ), + ( + {'unsupported':[]}, + Table("t5", metadata, + Column('id', String(10), primary_key=True), + Column('bar', String(30), PassiveDefault('hi')) + ), + {'id':'id1'}, + {'id':'id1', 'bar':'hi'}, + ), ]: + if testbase.db.name in supported['unsupported']: + continue try: table.create() - assert insert_values(table, values) == assertvalues + assert insert_values(table, values) == assertvalues, repr(values) + " " + repr(assertvalues) finally: table.drop() - - def testupdate(self): - - users.insert().execute(user_id = 7, user_name = 'jack') - assert users.count().scalar() == 1 - - users.update(users.c.user_id == 7).execute(user_name = 'fred') - assert users.select(users.c.user_id==7).execute().fetchone()['user_name'] == 'fred' def testrowiteration(self): users.insert().execute(user_id = 7, user_name = 'jack')