From: Mike Bayer Date: Sun, 28 Oct 2007 21:28:53 +0000 (+0000) Subject: - fixed INSERT statements w.r.t. primary key columns that have SQL-expression X-Git-Tag: rel_0_4_1~100 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=bbebcdf8f526226d2d64a91dfc306086fc9873f4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fixed INSERT statements w.r.t. primary key columns that have SQL-expression based default generators on them; SQL expression executes inline as normal but will not trigger a "postfetch" condition for the column, for those DB's who provide it via cursor.lastrowid --- diff --git a/CHANGES b/CHANGES index 718e752862..576aecf5d6 100644 --- a/CHANGES +++ b/CHANGES @@ -20,6 +20,11 @@ CHANGES - sqlite will reflect "DECIMAL" as a numeric column. +- fixed INSERT statements w.r.t. primary key columns that have SQL-expression + based default generators on them; SQL expression executes inline as normal + but will not trigger a "postfetch" condition for the column, for those DB's + who provide it via cursor.lastrowid + - Renamed the Dialect attribute 'preexecute_sequences' to 'preexecute_pk_sequences'. An attribute proxy is in place for out-of-tree dialects using the old name. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index e662f8e99d..43b6fb15b5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -664,8 +664,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): values.append((c, value)) elif isinstance(c, schema.Column): if self.isinsert: - if (c.primary_key and self.dialect.preexecute_pk_sequences - and not self.inline): + if (c.primary_key and self.dialect.preexecute_pk_sequences and not self.inline): if (((isinstance(c.default, schema.Sequence) and not c.default.optional) or not self.dialect.supports_pk_autoincrement) or @@ -676,17 +675,21 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): values.append((c, self.process(c.default.arg.self_group()))) - self.postfetch.add(c) + if not c.primary_key: + # dont add primary key column to postfetch + self.postfetch.add(c) else: values.append((c, create_bind_param(c, None))) self.prefetch.add(c) elif isinstance(c.default, schema.PassiveDefault): - self.postfetch.add(c) + if not c.primary_key: + self.postfetch.add(c) elif isinstance(c.default, schema.Sequence): proc = self.process(c.default) if proc is not None: values.append((c, proc)) - self.postfetch.add(c) + if not c.primary_key: + self.postfetch.add(c) elif self.isupdate: if isinstance(c.onupdate, schema.ColumnDefault): if isinstance(c.onupdate.arg, sql.ClauseElement): diff --git a/test/sql/defaults.py b/test/sql/defaults.py index c29ffa3b39..da23a67a21 100644 --- a/test/sql/defaults.py +++ b/test/sql/defaults.py @@ -243,6 +243,34 @@ class DefaultTest(PersistTest): finally: testbase.db.execute("drop table speedy_users", None) +class PKDefaultTest(PersistTest): + def setUpAll(self): + global metadata, t1, t2 + + metadata = MetaData(testbase.db) + + t2 = Table('t2', metadata, + Column('nextid', Integer)) + + t1 = Table('t1', metadata, + Column('id', Integer, primary_key=True, default=select([func.max(t2.c.nextid)]).as_scalar()), + Column('data', String(30))) + + metadata.create_all() + + def tearDownAll(self): + metadata.drop_all() + + def test_basic(self): + t2.insert().execute(nextid=1) + r = t1.insert().execute(data='hi') + assert r.last_inserted_ids() == [1] + + t2.insert().execute(nextid=2) + r = t1.insert().execute(data='there') + assert r.last_inserted_ids() == [2] + + class AutoIncrementTest(PersistTest): def setUp(self): global aitable, aimeta