continue
d.set_parameter(b.key, value, b)
- #print "FROM", params, "TO", d
return d
def get_named_params(self, parameters):
" ON " + self.get_str(join.onclause))
self.strings[join] = self.froms[join]
- def visit_insert_column_default(self, column, default):
+ def visit_insert_column_default(self, column, default, parameters):
"""called when visiting an Insert statement, for each column in the table that
contains a ColumnDefault object. adds a blank 'placeholder' parameter so the
Insert gets compiled with this column's name in its column and VALUES clauses."""
- self.parameters.setdefault(column.key, None)
+ parameters.setdefault(column.key, None)
- def visit_update_column_default(self, column, default):
+ def visit_update_column_default(self, column, default, parameters):
"""called when visiting an Update statement, for each column in the table that
contains a ColumnDefault object as an onupdate. adds a blank 'placeholder' parameter so the
Update gets compiled with this column's name as one of its SET clauses."""
- self.parameters.setdefault(column.key, None)
+ parameters.setdefault(column.key, None)
- def visit_insert_sequence(self, column, sequence):
+ def visit_insert_sequence(self, column, sequence, parameters):
"""called when visiting an Insert statement, for each column in the table that
contains a Sequence object. Overridden by compilers that support sequences to place
a blank 'placeholder' parameter, so the Insert gets compiled with this column's
name in its column and VALUES clauses."""
pass
- def visit_insert_column(self, column):
+ def visit_insert_column(self, column, parameters):
"""called when visiting an Insert statement, for each column in the table
that is a NULL insert into the table. Overridden by compilers who disallow
NULL columns being set in an Insert where there is a default value on the column
def visit_insert(self, insert_stmt):
# scan the table's columns for defaults that have to be pre-set for an INSERT
# add these columns to the parameter list via visit_insert_XXX methods
+ default_params = {}
class DefaultVisitor(schema.SchemaVisitor):
def visit_column(s, c):
- self.visit_insert_column(c)
+ self.visit_insert_column(c, default_params)
def visit_column_default(s, cd):
- self.visit_insert_column_default(c, cd)
+ self.visit_insert_column_default(c, cd, default_params)
def visit_sequence(s, seq):
- self.visit_insert_sequence(c, seq)
+ self.visit_insert_sequence(c, seq, default_params)
vis = DefaultVisitor()
for c in insert_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
c.accept_schema_visitor(vis)
self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt, default_params)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
- self.binds[p.shortname] = p
+ if p.shortname is not None:
+ self.binds[p.shortname] = p
return self.bindparam_string(p.key)
else:
p.accept_visitor(self)
text = ("INSERT INTO " + insert_stmt.table.fullname + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
" VALUES (" + string.join([create_param(c[1]) for c in colparams], ', ') + ")")
-
+
self.strings[insert_stmt] = text
def visit_update(self, update_stmt):
# scan the table's columns for onupdates that have to be pre-set for an UPDATE
# add these columns to the parameter list via visit_update_XXX methods
+ default_params = {}
class OnUpdateVisitor(schema.SchemaVisitor):
def visit_column_onupdate(s, cd):
- self.visit_update_column_default(c, cd)
+ self.visit_update_column_default(c, cd, default_params)
vis = OnUpdateVisitor()
for c in update_stmt.table.c:
if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
c.accept_schema_visitor(vis)
self.isupdate = True
- colparams = self._get_colparams(update_stmt)
+ colparams = self._get_colparams(update_stmt, default_params)
def create_param(p):
if isinstance(p, sql.BindParamClause):
self.binds[p.key] = p
self.strings[update_stmt] = text
- def _get_colparams(self, stmt):
+ def _get_colparams(self, stmt, default_params):
"""determines the VALUES or SET clause for an INSERT or UPDATE
clause based on the arguments specified to this ANSICompiler object
(i.e., the execute() or compile() method clause object):
for k, v in stmt.parameters.iteritems():
parameters.setdefault(k, v)
+ for k, v in default_params.iteritems():
+ parameters.setdefault(k, v)
+
# now go thru compiled params, get the Column object for each key
d = {}
for key, value in parameters.iteritems():
l = t.select().execute()
self.assert_(l.fetchall() == [(51, 'imthedefault', f, ts, ts, ctexec), (52, 'imthedefault', f, ts, ts, ctexec), (53, 'imthedefault', f, ts, ts, ctexec)])
+ def testinsertvalues(self):
+ t.insert(values={'col3':50}).execute()
+ l = t.select().execute()
+ self.assert_(l.fetchone()['col3'] == 50)
+
+
def testupdate(self):
t.insert().execute()
pk = t.engine.last_inserted_ids()[0]
self.assert_(l == (pk, 'im the update', f2, None, None, ctexec))
# mysql/other db's return 0 or 1 for count(1)
self.assert_(14 <= f2 <= 15)
+
+ def testupdatevalues(self):
+ t.insert().execute()
+ pk = t.engine.last_inserted_ids()[0]
+ t.update(t.c.col1==pk, values={'col3': 55}).execute()
+ l = t.select(t.c.col1==pk).execute()
+ l = l.fetchone()
+ self.assert_(l['col3'] == 55)
class SequenceTest(PersistTest):