# No offset provided, so just use the limit
return ' \n LIMIT %s' % (self.process(sql.literal(limit)),)
- def visit_update(self, update_stmt):
- self.stack.append({'from': set([update_stmt.table])})
-
- self.isupdate = True
- colparams = self._get_colparams(update_stmt)
-
- text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
- " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
-
- if update_stmt._whereclause is not None:
- text += " WHERE " + self.process(update_stmt._whereclause)
-
+ def update_limit_clause(self, update_stmt):
limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None)
if limit:
- text += " LIMIT %s" % limit
+ return "LIMIT %s" % limit
+ else:
+ return None
- self.stack.pop(-1)
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw)
+ for t in [from_table] + list(extra_froms))
- return text
+ def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return None
+
+ def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
+ return bool(extra_froms)
# ug. "InnoDB needs indexes on foreign keys and referenced keys [...].
# Starting with MySQL 4.1.2, these indexes are created automatically.
return text
- def visit_update(self, update_stmt):
+ def update_limit_clause(self, update_stmt):
+ return None
+
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return self.preparer.format_table(from_table)
+
+ def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+ return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms)
+
+ def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
+ return False
+
+ def visit_update(self, update_stmt, **kw):
self.stack.append({'from': set([update_stmt.table])})
self.isupdate = True
- colparams = self._get_colparams(update_stmt)
- text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+ if update_stmt._whereclause is not None:
+ extra_froms = set(update_stmt._whereclause._from_objects).\
+ difference([update_stmt.table])
+ else:
+ extra_froms = set()
+
+ colparams = self._get_colparams(update_stmt, extra_froms)
+
+ #for c in colparams:
+ # if hasattr(c[1], '_from_objects'):
+ # extra_froms.update(c[1]._from_objects)
- text += ' SET ' + \
+ text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+
+ if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms):
+ text += ' SET ' + \
+ ', '.join(
+ self.visit_column(c[0]) +
+ '=' + c[1]
+ for c in colparams
+ )
+ else:
+ text += ' SET ' + \
', '.join(
self.preparer.quote(c[0].name, c[0].quote) +
'=' + c[1]
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
+ if extra_froms:
+ extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+ if extra_from_text:
+ text += " " + extra_from_text
+
if update_stmt._whereclause is not None:
text += " WHERE " + self.process(update_stmt._whereclause)
+ limit_clause = self.update_limit_clause(update_stmt)
+ if limit_clause:
+ text += " " + limit_clause
+
if self.returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
update_stmt, update_stmt._returning)
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt):
+ def _get_colparams(self, stmt, extra_tables=None):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
(
implicit_returning or
not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
+ c is not t._autoincrement_column
):
if implicit_returning:
self.returning.append(c)
else:
if c.default is not None or \
- c is stmt.table._autoincrement_column and (
+ c is t._autoincrement_column and (
self.dialect.supports_sequences or
self.dialect.preexecute_autoincrement_sequences
):
"(SELECT myothertable.othername FROM myothertable "
"WHERE myothertable.otherid = mytable.myid)")
+ # test correlated FROM implicit in WHERE and SET clauses
+ u = table1.update().values(name=table2.c.othername)\
+ .where(table2.c.otherid == table1.c.myid)
+ self.assert_compile(u,
+ "UPDATE mytable SET name=myothertable.othername "
+ "FROM myothertable WHERE myothertable.otherid = mytable.myid")
+ u = table1.update().values(name='foo')\
+ .where(table2.c.otherid == table1.c.myid)
+ self.assert_compile(u,
+ "UPDATE mytable SET name=:name "
+ "FROM myothertable WHERE myothertable.otherid = mytable.myid")
+
def test_delete(self):
self.assert_compile(
delete(table1, table1.c.myid == 7),