- added inline UPDATE/INSERT clauses, settable as regular object attributes.
author Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 Aug 2007 04:24:02 +0000 (04:24 +0000)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Thu, 2 Aug 2007 04:24:02 +0000 (04:24 +0000)
      the clause gets executed inline during a flush().
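
In short: an instance attribute may now be assigned a SQL clause (a ClauseElement) rather than a literal value, and the next flush() folds that clause directly into the generated INSERT or UPDATE. A condensed sketch of the usage, based on the ClauseAttributesTest added below (the User mapping and users_table come from that test):

    sess = Session()
    u = User(name='test')
    sess.save(u)
    sess.flush()                             # INSERT; counter gets its column default of 1

    u.counter = users_table.c.counter + 1    # assign a SQL expression, not a literal
    sess.flush()                             # UPDATE users SET counter=(users.counter + 1) ...
    assert u.counter == 2                    # post-fetched from the database on next access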

CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/sql.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index 9fcdb3c062a358ee25afed6c1a1f7d92fc2bbbc0..2b4f9d482903c983ee9237a56438de03743d8eff 100644
--- a/CHANGES
+++ b/CHANGES
@@ -94,6 +94,9 @@
       created/mapped to a single attribute, comprised of the values
       correponding to *columns [ticket:211]
 
+    - added inline UPDATE/INSERT clauses, settable as regular object attributes.
+      the clause gets executed inline during a flush().
+
     - improved support for custom column_property() attributes which
       feature correlated subqueries...work better with eager loading now.
 
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index d50fb5a815c146e6606d1b53dbfe9958adf46769..acff8733253992fdbfa349bfe687ace20637d1e7 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -607,7 +607,7 @@ class AttributeHistory(object):
                     self._deleted_items.append(a)
         else:
             self._current = [current]
-            if attr.is_equal(current, original):
+            if attr.is_equal(current, original) is True:
                 self._unchanged_items = [current]
                 self._added_items = []
                 self._deleted_items = []
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index af7c9d4cffe2e1fc06be44f0dc5c6c3a58d41643..4d668e3b55f9139d8eb6cd3dcefec42efb8f4e9a 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1048,6 +1048,7 @@ class Mapper(object):
 
                 isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity(obj)
                 params = {}
+                value_params = {}
                 hasdata = False
                 for col in table.columns:
                     if col is mapper.version_id_col:
@@ -1097,7 +1098,10 @@ class Mapper(object):
                             if history:
                                 a = history.added_items()
                                 if len(a):
-                                    params[col.key] = prop.get_col_value(col, a[0])
+                                    if isinstance(a[0], sql.ClauseElement):
+                                        value_params[col] = a[0]
+                                    else:
+                                        params[col.key] = prop.get_col_value(col, a[0])
                                     hasdata = True
                         else:
                             # doing an INSERT, non primary key col ?
@@ -1110,15 +1114,18 @@ class Mapper(object):
                             if value is NO_ATTRIBUTE:
                                 continue
                             if col.default is None or value is not None:
-                                params[col.key] = value
+                                if isinstance(value, sql.ClauseElement):
+                                    value_params[col] = value
+                                else:
+                                    params[col.key] = value
 
                 if not isinsert:
                     if hasdata:
                         # if none of the attributes changed, dont even
                         # add the row to be updated.
-                        update.append((obj, params, mapper, connection))
+                        update.append((obj, params, mapper, connection, value_params))
                 else:
-                    insert.append((obj, params, mapper, connection))
+                    insert.append((obj, params, mapper, connection, value_params))
 
             if len(update):
                 mapper = table_to_mapper[table]
@@ -1138,9 +1145,9 @@ class Mapper(object):
                     return 0
                 update.sort(comparator)
                 for rec in update:
-                    (obj, params, mapper, connection) = rec
-                    c = connection.execute(statement, params)
-                    mapper._postfetch(connection, table, obj, c, c.last_updated_params())
+                    (obj, params, mapper, connection, value_params) = rec
+                    c = connection.execute(statement.values(value_params), params)
+                    mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params)
 
                     updated_objects.add((obj, connection))
                     rows += c.rowcount
@@ -1154,8 +1161,8 @@ class Mapper(object):
                     return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order)
                 insert.sort(comparator)
                 for rec in insert:
-                    (obj, params, mapper, connection) = rec
-                    c = connection.execute(statement, params)
+                    (obj, params, mapper, connection, value_params) = rec
+                    c = connection.execute(statement.values(value_params), params)
                     primary_key = c.last_inserted_ids()
                     if primary_key is not None:
                         i = 0
@@ -1163,7 +1170,7 @@ class Mapper(object):
                             if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
                                 mapper.set_attr_by_column(obj, col, primary_key[i])
                             i+=1
-                    mapper._postfetch(connection, table, obj, c, c.last_inserted_params())
+                    mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params)
 
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
@@ -1185,22 +1192,23 @@ class Mapper(object):
                 for mapper in object_mapper(obj).iterate_to_root():
                     mapper.extension.after_update(mapper, connection, obj)
 
-    def _postfetch(self, connection, table, obj, resultproxy, params):
+    def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated
         on the database side, set up a group-based "deferred" loader 
         which will populate those attributes in one query when next accessed.
         """
 
-        postfetch_cols = resultproxy.context.postfetch_cols()
+        postfetch_cols = resultproxy.context.postfetch_cols().union(util.Set(value_params.keys())) 
         deferred_props = []
 
         for c in table.c:
-            if c in postfetch_cols and not c.key in params:
+            if c in postfetch_cols and (not c.key in params or c in value_params):
                 prop = self._getpropbycolumn(c, raiseerror=False)
                 if prop is None:
                     continue
                 deferred_props.append(prop)
+                continue
             if c.primary_key or not c.key in params:
                 continue
             v = self.get_attr_by_column(obj, c, False)
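
In effect, the flush now splits each row's values into two dictionaries: plain values stay in the bind parameters (params), while ClauseElement values are collected in value_params and folded into the statement itself via statement.values(value_params); columns written that way are also added to the post-fetch set so the attribute is re-read from the database afterwards. The statement the flush executes is roughly equivalent to this hand-built construct (a sketch only; the table mirrors the one in the new tests):

    from sqlalchemy import MetaData, Table, Column, Integer, String

    metadata = MetaData()
    users_table = Table('users', metadata,
        Column('id', Integer, primary_key=True),
        Column('name', String(30)),
        Column('counter', Integer, default=1))

    # literal values travel as bind parameters; the clause becomes part of the SET list
    stmt = users_table.update().values({users_table.c.counter: users_table.c.counter + 1})
    # UPDATE users SET counter=(users.counter + 1)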
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 3afaef58e581e4ba0d14f59e0c1227908d14f996..8e8a12895de481c9f96b41e665d66b68014ad0c7 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -3380,6 +3380,8 @@ class Insert(_UpdateBase):
         self.parameters = self.parameters.copy()
 
     def values(self, v):
+        if len(v) == 0:
+            return self
         u = self._clone()
         if u.parameters is None:
             u.parameters = u._process_colparams(v)
@@ -3405,6 +3407,8 @@ class Update(_UpdateBase):
         self.parameters = self.parameters.copy()
         
     def values(self, v):
+        if len(v) == 0:
+            return self
         u = self._clone()
         if u.parameters is None:
             u.parameters = u._process_colparams(v)
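
The empty-dict short circuit above lets the flush call statement.values(value_params) unconditionally: when no clause-valued attributes are present, the original statement is reused unchanged rather than cloned. A small illustration, reusing the users_table sketch from above:

    stmt = users_table.insert()
    assert stmt.values({}) is stmt                        # no parameters: same statement object comes back
    stmt2 = stmt.values({users_table.c.counter: 5})
    assert stmt2 is not stmt                              # non-empty values() still produces a copy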
diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py
index a3ee0654e57845630ff42281f38bf7a78e85a8d6..293021b2cfde1c62958a8a50f396ad3fb48884d8 100644
--- a/test/orm/unitofwork.py
+++ b/test/orm/unitofwork.py
@@ -454,6 +454,74 @@ class ForeignPKTest(UnitOfWorkTest):
         Session.commit()
         assert people.count(people.c.person=='im the key').scalar() == peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1
 
+class ClauseAttributesTest(UnitOfWorkTest):
+    def setUpAll(self):
+        UnitOfWorkTest.setUpAll(self)
+        global metadata, users_table
+        metadata = MetaData(testbase.db)
+        users_table = Table('users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30)),
+            Column('counter', Integer, default=1))
+        metadata.create_all()
+    
+    def tearDown(self):
+        users_table.delete().execute()
+        UnitOfWorkTest.tearDown(self)
+        
+    def tearDownAll(self):
+        metadata.drop_all()
+        UnitOfWorkTest.tearDownAll(self)
+        
+    def test_update(self):
+        class User(object):
+            pass
+        mapper(User, users_table)
+        u = User(name='test')
+        sess = Session()
+        sess.save(u)
+        sess.flush()
+        assert u.counter == 1
+        u.counter = users_table.c.counter + 1
+        sess.flush()
+        def go():
+            assert u.counter == 2
+        self.assert_sql_count(testbase.db, go, 1)
+
+    def test_multi_update(self):
+        class User(object):
+            pass
+        mapper(User, users_table)
+        u = User(name='test')
+        sess = Session()
+        sess.save(u)
+        sess.flush()
+        assert u.counter == 1
+        u.name = 'test2'
+        u.counter = users_table.c.counter + 1
+        sess.flush()
+        def go():
+            assert u.name == 'test2'
+            assert u.counter == 2
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        sess.clear()
+        u = sess.query(User).get(u.id)
+        assert u.name == 'test2'
+        assert u.counter == 2
+    
+    def test_insert(self):
+        class User(object):
+            pass
+        mapper(User, users_table)
+        u = User(name='test', counter=select([5]))
+        sess = Session()
+        sess.save(u)
+        sess.flush()
+        assert u.counter == 5
+        
+
+        
 class PassiveDeletesTest(UnitOfWorkTest):
     def setUpAll(self):
         UnitOfWorkTest.setUpAll(self)