]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
also add support for onupdate as we'd like this to fire off if an UPDATE actually
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 23:46:45 +0000 (18:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 23:46:45 +0000 (18:46 -0500)
happens on the table

lib/sqlalchemy/sql/compiler.py
test/sql/test_update.py

index 4b1b9bd5d770914fe54cf6364f16a9bfc749ebf0..7aee5da81869009e7c8dbd9cb2a5a46f42f3e2d5 100644 (file)
@@ -1140,23 +1140,41 @@ class SQLCompiler(engine.Compiled):
         # special logic that only occurs for multi-table UPDATE 
         # statements
         if extra_tables and stmt.parameters:
+            assert self.isupdate
+            affected_tables = set()
             for t in extra_tables:
                 for c in t.c:
                     if c in stmt.parameters:
+                        affected_tables.add(t)
                         check_columns[c.key] = c
-
-            for c in check_columns.values():
-                value = stmt.parameters[c]
-                if sql._is_literal(value):
-                    value = self._create_crud_bind_param(
-                                    c, value, required=value is required)
-                elif c.primary_key and implicit_returning:
-                    self.returning.append(c)
-                    value = self.process(value.self_group())
-                else:
-                    self.postfetch.append(c)
-                    value = self.process(value.self_group())
-                values.append((c, value))
+                        value = stmt.parameters[c]
+                        if sql._is_literal(value):
+                            value = self._create_crud_bind_param(
+                                            c, value, required=value is required)
+                        else:
+                            self.postfetch.append(c)
+                            value = self.process(value.self_group())
+                        values.append((c, value))
+            # determine tables which are actually
+            # to be updated - process onupdate and 
+            # server_onupdate for these
+            for t in affected_tables:
+                for c in t.c:
+                    if c in stmt.parameters:
+                        continue
+                    elif c.onupdate is not None and not c.onupdate.is_sequence:
+                        if c.onupdate.is_clause_element:
+                            values.append(
+                                (c, self.process(c.onupdate.arg.self_group()))
+                            )
+                            self.postfetch.append(c)
+                        else:
+                            values.append(
+                                (c, self._create_crud_bind_param(c, None))
+                            )
+                            self.prefetch.append(c)
+                    elif c.server_onupdate is not None:
+                        self.postfetch.append(c)
 
         # iterating through columns at the top to maintain ordering.
         # otherwise we might iterate through individual sets of 
index 2ea3d92a4d18b090540af12e9d407a20768119db..8eccde999b33ff7f9dd2555ba3a9230bbce8a7f2 100644 (file)
@@ -14,8 +14,6 @@ class _UpdateFromTestBase(object):
               Column('id', Integer, primary_key=True, 
                             test_needs_autoincrement=True),
               Column('name', String(30), nullable=False),
-              test_needs_acid=True,
-              test_needs_fk=True
         )
 
         Table('addresses', metadata,
@@ -24,8 +22,6 @@ class _UpdateFromTestBase(object):
               Column('user_id', None, ForeignKey('users.id')),
               Column('name', String(30), nullable=False),
               Column('email_address', String(50), nullable=False),
-              test_needs_acid=True,
-              test_needs_fk=True
         )
 
         Table("dingalings", metadata,
@@ -33,8 +29,6 @@ class _UpdateFromTestBase(object):
                             test_needs_autoincrement=True),
               Column('address_id', None, ForeignKey('addresses.id')),
               Column('data', String(30)),
-              test_needs_acid=True,
-              test_needs_fk=True
         )
 
     @classmethod
@@ -222,3 +216,105 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
                 (10, 'chuck')
             ]
         )
+
+class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase, fixtures.TablesTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('users', metadata,
+              Column('id', Integer, primary_key=True, 
+                            test_needs_autoincrement=True),
+              Column('name', String(30), nullable=False),
+              Column('some_update', String(30), onupdate="im the update")
+        )
+
+        Table('addresses', metadata,
+              Column('id', Integer, primary_key=True, 
+                            test_needs_autoincrement=True),
+              Column('user_id', None, ForeignKey('users.id')),
+              Column('email_address', String(50), nullable=False),
+        )
+
+    @classmethod
+    def fixtures(cls):
+        return dict(
+            users = (
+                ('id', 'name', 'some_update'),
+                (8, 'ed', 'value'),
+                (9, 'fred', 'value'),
+            ),
+
+            addresses = (
+                ('id', 'user_id', 'email_address'),
+                (2, 8, "ed@wood.com"),
+                (3, 8, "ed@bettyboop.com"),
+                (4, 9, "fred@fred.com")
+            ),
+        )
+
+    @testing.only_on('mysql', 'Multi table update')
+    def test_defaults_second_table(self):
+        users, addresses = self.tables.users, self.tables.addresses
+        ret = testing.db.execute(
+            addresses.update().\
+                values({
+                        addresses.c.email_address:users.c.name, 
+                        users.c.name:'ed2'
+                }).\
+                where(users.c.id==addresses.c.user_id).\
+                where(users.c.name=='ed')
+        )
+        eq_(
+            set(ret.prefetch_cols()),
+            set([users.c.some_update])
+        )
+        eq_(
+            testing.db.execute(
+                addresses.select().order_by(addresses.c.id)).fetchall(),
+            [
+                (2, 8, "ed"),
+                (3, 8, "ed"),
+                (4, 9, "fred@fred.com")
+            ]
+        )
+        eq_(
+            testing.db.execute(
+                users.select().order_by(users.c.id)).fetchall(),
+            [
+                (8, 'ed2', 'im the update'),
+                (9, 'fred', 'value'),
+            ]
+        )
+
+    @testing.only_on('mysql', 'Multi table update')
+    def test_no_defaults_second_table(self):
+        users, addresses = self.tables.users, self.tables.addresses
+        ret = testing.db.execute(
+            addresses.update().\
+                values({
+                        'email_address':users.c.name, 
+                }).\
+                where(users.c.id==addresses.c.user_id).\
+                where(users.c.name=='ed')
+        )
+        eq_(
+            ret.prefetch_cols(),[]
+        )
+        eq_(
+            testing.db.execute(
+                addresses.select().order_by(addresses.c.id)).fetchall(),
+            [
+                (2, 8, "ed"),
+                (3, 8, "ed"),
+                (4, 9, "fred@fred.com")
+            ]
+        )
+        # users table not actually updated, 
+        # so no onupdate
+        eq_(
+            testing.db.execute(
+                users.select().order_by(users.c.id)).fetchall(),
+            [
+                (8, 'ed', 'value'),
+                (9, 'fred', 'value'),
+            ]
+        )