From: Mike Bayer Date: Tue, 22 Nov 2011 23:46:45 +0000 (-0500) Subject: also add support for onupdate as we'd like this to fire off if an UPDATE actually X-Git-Tag: rel_0_7_4~53 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9c896906c7e4130ea11cf913dd50d29a9a3e1fa7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git also add support for onupdate as we'd like this to fire off if an UPDATE actually happens on the table --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 4b1b9bd5d7..7aee5da818 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1140,23 +1140,41 @@ class SQLCompiler(engine.Compiled): # special logic that only occurs for multi-table UPDATE # statements if extra_tables and stmt.parameters: + assert self.isupdate + affected_tables = set() for t in extra_tables: for c in t.c: if c in stmt.parameters: + affected_tables.add(t) check_columns[c.key] = c - - for c in check_columns.values(): - value = stmt.parameters[c] - if sql._is_literal(value): - value = self._create_crud_bind_param( - c, value, required=value is required) - elif c.primary_key and implicit_returning: - self.returning.append(c) - value = self.process(value.self_group()) - else: - self.postfetch.append(c) - value = self.process(value.self_group()) - values.append((c, value)) + value = stmt.parameters[c] + if sql._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is required) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + # determine tables which are actually + # to be updated - process onupdate and + # server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in stmt.parameters: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + (c, self.process(c.onupdate.arg.self_group())) + ) + self.postfetch.append(c) + else: + values.append( + (c, self._create_crud_bind_param(c, None)) + ) + self.prefetch.append(c) + elif c.server_onupdate is not None: + self.postfetch.append(c) # iterating through columns at the top to maintain ordering. # otherwise we might iterate through individual sets of diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 2ea3d92a4d..8eccde999b 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -14,8 +14,6 @@ class _UpdateFromTestBase(object): Column('id', Integer, primary_key=True, test_needs_autoincrement=True), Column('name', String(30), nullable=False), - test_needs_acid=True, - test_needs_fk=True ) Table('addresses', metadata, @@ -24,8 +22,6 @@ class _UpdateFromTestBase(object): Column('user_id', None, ForeignKey('users.id')), Column('name', String(30), nullable=False), Column('email_address', String(50), nullable=False), - test_needs_acid=True, - test_needs_fk=True ) Table("dingalings", metadata, @@ -33,8 +29,6 @@ class _UpdateFromTestBase(object): test_needs_autoincrement=True), Column('address_id', None, ForeignKey('addresses.id')), Column('data', String(30)), - test_needs_acid=True, - test_needs_fk=True ) @classmethod @@ -222,3 +216,105 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (10, 'chuck') ] ) + +class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table('users', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('name', String(30), nullable=False), + Column('some_update', String(30), onupdate="im the update") + ) + + Table('addresses', metadata, + Column('id', Integer, primary_key=True, + test_needs_autoincrement=True), + Column('user_id', None, ForeignKey('users.id')), + Column('email_address', String(50), nullable=False), + ) + + @classmethod + def fixtures(cls): + return dict( + users = ( + ('id', 'name', 'some_update'), + (8, 'ed', 'value'), + (9, 'fred', 'value'), + ), + + addresses = ( + ('id', 'user_id', 'email_address'), + (2, 8, "ed@wood.com"), + (3, 8, "ed@bettyboop.com"), + (4, 9, "fred@fred.com") + ), + ) + + @testing.only_on('mysql', 'Multi table update') + def test_defaults_second_table(self): + users, addresses = self.tables.users, self.tables.addresses + ret = testing.db.execute( + addresses.update().\ + values({ + addresses.c.email_address:users.c.name, + users.c.name:'ed2' + }).\ + where(users.c.id==addresses.c.user_id).\ + where(users.c.name=='ed') + ) + eq_( + set(ret.prefetch_cols()), + set([users.c.some_update]) + ) + eq_( + testing.db.execute( + addresses.select().order_by(addresses.c.id)).fetchall(), + [ + (2, 8, "ed"), + (3, 8, "ed"), + (4, 9, "fred@fred.com") + ] + ) + eq_( + testing.db.execute( + users.select().order_by(users.c.id)).fetchall(), + [ + (8, 'ed2', 'im the update'), + (9, 'fred', 'value'), + ] + ) + + @testing.only_on('mysql', 'Multi table update') + def test_no_defaults_second_table(self): + users, addresses = self.tables.users, self.tables.addresses + ret = testing.db.execute( + addresses.update().\ + values({ + 'email_address':users.c.name, + }).\ + where(users.c.id==addresses.c.user_id).\ + where(users.c.name=='ed') + ) + eq_( + ret.prefetch_cols(),[] + ) + eq_( + testing.db.execute( + addresses.select().order_by(addresses.c.id)).fetchall(), + [ + (2, 8, "ed"), + (3, 8, "ed"), + (4, 9, "fred@fred.com") + ] + ) + # users table not actually updated, + # so no onupdate + eq_( + testing.db.execute( + users.select().order_by(users.c.id)).fetchall(), + [ + (8, 'ed', 'value'), + (9, 'fred', 'value'), + ] + )