From 409acb42086bb82bdfc784dafef5a6fa50afd0e0 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 28 Apr 2006 23:31:59 +0000 Subject: [PATCH] fix for [ticket:169], moves the creation of "default" parameters more accurately where theyre supposed to be --- lib/sqlalchemy/ansisql.py | 37 ++++++++++++++++------------ lib/sqlalchemy/databases/postgres.py | 4 +-- test/defaults.py | 14 +++++++++++ 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 7aefea0ba2..a344e017c7 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -146,7 +146,6 @@ class ANSICompiler(sql.Compiled): continue d.set_parameter(b.key, value, b) - #print "FROM", params, "TO", d return d def get_named_params(self, parameters): @@ -425,26 +424,26 @@ class ANSICompiler(sql.Compiled): " ON " + self.get_str(join.onclause)) self.strings[join] = self.froms[join] - def visit_insert_column_default(self, column, default): + def visit_insert_column_default(self, column, default, parameters): """called when visiting an Insert statement, for each column in the table that contains a ColumnDefault object. adds a blank 'placeholder' parameter so the Insert gets compiled with this column's name in its column and VALUES clauses.""" - self.parameters.setdefault(column.key, None) + parameters.setdefault(column.key, None) - def visit_update_column_default(self, column, default): + def visit_update_column_default(self, column, default, parameters): """called when visiting an Update statement, for each column in the table that contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the Update gets compiled with this column's name as one of its SET clauses.""" - self.parameters.setdefault(column.key, None) + parameters.setdefault(column.key, None) - def visit_insert_sequence(self, column, sequence): + def visit_insert_sequence(self, column, sequence, parameters): """called when visiting an Insert statement, for each column in the table that contains a Sequence object. Overridden by compilers that support sequences to place a blank 'placeholder' parameter, so the Insert gets compiled with this column's name in its column and VALUES clauses.""" pass - def visit_insert_column(self, column): + def visit_insert_column(self, column, parameters): """called when visiting an Insert statement, for each column in the table that is a NULL insert into the table. Overridden by compilers who disallow NULL columns being set in an Insert where there is a default value on the column @@ -454,25 +453,27 @@ class ANSICompiler(sql.Compiled): def visit_insert(self, insert_stmt): # scan the table's columns for defaults that have to be pre-set for an INSERT # add these columns to the parameter list via visit_insert_XXX methods + default_params = {} class DefaultVisitor(schema.SchemaVisitor): def visit_column(s, c): - self.visit_insert_column(c) + self.visit_insert_column(c, default_params) def visit_column_default(s, cd): - self.visit_insert_column_default(c, cd) + self.visit_insert_column_default(c, cd, default_params) def visit_sequence(s, seq): - self.visit_insert_sequence(c, seq) + self.visit_insert_sequence(c, seq, default_params) vis = DefaultVisitor() for c in insert_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): c.accept_schema_visitor(vis) self.isinsert = True - colparams = self._get_colparams(insert_stmt) + colparams = self._get_colparams(insert_stmt, default_params) def create_param(p): if isinstance(p, sql.BindParamClause): self.binds[p.key] = p - self.binds[p.shortname] = p + if p.shortname is not None: + self.binds[p.shortname] = p return self.bindparam_string(p.key) else: p.accept_visitor(self) @@ -483,22 +484,23 @@ class ANSICompiler(sql.Compiled): text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" + " VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")") - + self.strings[insert_stmt] = text def visit_update(self, update_stmt): # scan the table's columns for onupdates that have to be pre-set for an UPDATE # add these columns to the parameter list via visit_update_XXX methods + default_params = {} class OnUpdateVisitor(schema.SchemaVisitor): def visit_column_onupdate(s, cd): - self.visit_update_column_default(c, cd) + self.visit_update_column_default(c, cd, default_params) vis = OnUpdateVisitor() for c in update_stmt.table.c: if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)): c.accept_schema_visitor(vis) self.isupdate = True - colparams = self._get_colparams(update_stmt) + colparams = self._get_colparams(update_stmt, default_params) def create_param(p): if isinstance(p, sql.BindParamClause): self.binds[p.key] = p @@ -519,7 +521,7 @@ class ANSICompiler(sql.Compiled): self.strings[update_stmt] = text - def _get_colparams(self, stmt): + def _get_colparams(self, stmt, default_params): """determines the VALUES or SET clause for an INSERT or UPDATE clause based on the arguments specified to this ANSICompiler object (i.e., the execute() or compile() method clause object): @@ -550,6 +552,9 @@ class ANSICompiler(sql.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(k, v) + for k, v in default_params.iteritems(): + parameters.setdefault(k, v) + # now go thru compiled params, get the Column object for each key d = {} for key, value in parameters.iteritems(): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 8a063ca06e..19a703c0ce 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -291,13 +291,13 @@ class PGSQLEngine(ansisql.ANSISQLEngine): class PGCompiler(ansisql.ANSICompiler): - def visit_insert_column(self, column): + def visit_insert_column(self, column, parameters): # Postgres advises against OID usage and turns it off in 8.1, # effectively making cursor.lastrowid # useless, effectively making reliance upon SERIAL useless. # so all column primary key inserts must be explicitly present if column.primary_key: - self.parameters[column.key] = None + parameters[column.key] = None def limit_clause(self, select): text = "" diff --git a/test/defaults.py b/test/defaults.py index 0d91d12a48..096355826f 100644 --- a/test/defaults.py +++ b/test/defaults.py @@ -92,6 +92,12 @@ class DefaultTest(PersistTest): l = t.select().execute() self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)]) + def testinsertvalues(self): + t.insert(values={'col3':50}).execute() + l = t.select().execute() + self.assert_(l.fetchone()['col3'] == 50) + + def testupdate(self): t.insert().execute() pk = t.engine.last_inserted_ids()[0] @@ -103,6 +109,14 @@ class DefaultTest(PersistTest): self.assert_(l == (pk, 'im the update', f2, None, None, ctexec)) # mysql/other db's return 0 or 1 for count(1) self.assert_(14 <= f2 <= 15) + + def testupdatevalues(self): + t.insert().execute() + pk = t.engine.last_inserted_ids()[0] + t.update(t.c.col1==pk, values={'col3': 55}).execute() + l = t.select(t.c.col1==pk).execute() + l = l.fetchone() + self.assert_(l['col3'] == 55) class SequenceTest(PersistTest): -- 2.47.2