From: Mike Bayer Date: Thu, 2 Aug 2007 04:24:02 +0000 (+0000) Subject: - added inline UPDATE/INSERT clauses, settable as regular object attributes. X-Git-Tag: rel_0_4beta1~106 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9f23ec7423e98305f43a0b7a7ef894da74325329;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added inline UPDATE/INSERT clauses, settable as regular object attributes. the clause gets executed inline during a flush(). --- diff --git a/CHANGES b/CHANGES index 9fcdb3c062..2b4f9d4829 100644 --- a/CHANGES +++ b/CHANGES @@ -94,6 +94,9 @@ created/mapped to a single attribute, comprised of the values correponding to *columns [ticket:211] + - added inline UPDATE/INSERT clauses, settable as regular object attributes. + the clause gets executed inline during a flush(). + - improved support for custom column_property() attributes which feature correlated subqueries...work better with eager loading now. diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index d50fb5a815..acff873325 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -607,7 +607,7 @@ class AttributeHistory(object): self._deleted_items.append(a) else: self._current = [current] - if attr.is_equal(current, original): + if attr.is_equal(current, original) is True: self._unchanged_items = [current] self._added_items = [] self._deleted_items = [] diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index af7c9d4cff..4d668e3b55 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1048,6 +1048,7 @@ class Mapper(object): isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity(obj) params = {} + value_params = {} hasdata = False for col in table.columns: if col is mapper.version_id_col: @@ -1097,7 +1098,10 @@ class Mapper(object): if history: a = history.added_items() if len(a): - params[col.key] = prop.get_col_value(col, a[0]) + if isinstance(a[0], sql.ClauseElement): + value_params[col] = a[0] + else: + params[col.key] = prop.get_col_value(col, a[0]) hasdata = True else: # doing an INSERT, non primary key col ? @@ -1110,15 +1114,18 @@ class Mapper(object): if value is NO_ATTRIBUTE: continue if col.default is None or value is not None: - params[col.key] = value + if isinstance(value, sql.ClauseElement): + value_params[col] = value + else: + params[col.key] = value if not isinsert: if hasdata: # if none of the attributes changed, dont even # add the row to be updated. - update.append((obj, params, mapper, connection)) + update.append((obj, params, mapper, connection, value_params)) else: - insert.append((obj, params, mapper, connection)) + insert.append((obj, params, mapper, connection, value_params)) if len(update): mapper = table_to_mapper[table] @@ -1138,9 +1145,9 @@ class Mapper(object): return 0 update.sort(comparator) for rec in update: - (obj, params, mapper, connection) = rec - c = connection.execute(statement, params) - mapper._postfetch(connection, table, obj, c, c.last_updated_params()) + (obj, params, mapper, connection, value_params) = rec + c = connection.execute(statement.values(value_params), params) + mapper._postfetch(connection, table, obj, c, c.last_updated_params(), value_params) updated_objects.add((obj, connection)) rows += c.rowcount @@ -1154,8 +1161,8 @@ class Mapper(object): return cmp(a[0]._sa_insert_order, b[0]._sa_insert_order) insert.sort(comparator) for rec in insert: - (obj, params, mapper, connection) = rec - c = connection.execute(statement, params) + (obj, params, mapper, connection, value_params) = rec + c = connection.execute(statement.values(value_params), params) primary_key = c.last_inserted_ids() if primary_key is not None: i = 0 @@ -1163,7 +1170,7 @@ class Mapper(object): if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i: mapper.set_attr_by_column(obj, col, primary_key[i]) i+=1 - mapper._postfetch(connection, table, obj, c, c.last_inserted_params()) + mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules @@ -1185,22 +1192,23 @@ class Mapper(object): for mapper in object_mapper(obj).iterate_to_root(): mapper.extension.after_update(mapper, connection, obj) - def _postfetch(self, connection, table, obj, resultproxy, params): + def _postfetch(self, connection, table, obj, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated on the database side, set up a group-based "deferred" loader which will populate those attributes in one query when next accessed. """ - postfetch_cols = resultproxy.context.postfetch_cols() + postfetch_cols = resultproxy.context.postfetch_cols().union(util.Set(value_params.keys())) deferred_props = [] for c in table.c: - if c in postfetch_cols and not c.key in params: + if c in postfetch_cols and (not c.key in params or c in value_params): prop = self._getpropbycolumn(c, raiseerror=False) if prop is None: continue deferred_props.append(prop) + continue if c.primary_key or not c.key in params: continue v = self.get_attr_by_column(obj, c, False) diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 3afaef58e5..8e8a12895d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -3380,6 +3380,8 @@ class Insert(_UpdateBase): self.parameters = self.parameters.copy() def values(self, v): + if len(v) == 0: + return self u = self._clone() if u.parameters is None: u.parameters = u._process_colparams(v) @@ -3405,6 +3407,8 @@ class Update(_UpdateBase): self.parameters = self.parameters.copy() def values(self, v): + if len(v) == 0: + return self u = self._clone() if u.parameters is None: u.parameters = u._process_colparams(v) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index a3ee0654e5..293021b2cf 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -454,6 +454,74 @@ class ForeignPKTest(UnitOfWorkTest): Session.commit() assert people.count(people.c.person=='im the key').scalar() == peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1 +class ClauseAttributesTest(UnitOfWorkTest): + def setUpAll(self): + UnitOfWorkTest.setUpAll(self) + global metadata, users_table + metadata = MetaData(testbase.db) + users_table = Table('users', metadata, + Column('id', Integer, primary_key=True), + Column('name', String(30)), + Column('counter', Integer, default=1)) + metadata.create_all() + + def tearDown(self): + users_table.delete().execute() + UnitOfWorkTest.tearDown(self) + + def tearDownAll(self): + metadata.drop_all() + UnitOfWorkTest.tearDownAll(self) + + def test_update(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test') + sess = Session() + sess.save(u) + sess.flush() + assert u.counter == 1 + u.counter = users_table.c.counter + 1 + sess.flush() + def go(): + assert u.counter == 2 + self.assert_sql_count(testbase.db, go, 1) + + def test_multi_update(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test') + sess = Session() + sess.save(u) + sess.flush() + assert u.counter == 1 + u.name = 'test2' + u.counter = users_table.c.counter + 1 + sess.flush() + def go(): + assert u.name == 'test2' + assert u.counter == 2 + self.assert_sql_count(testbase.db, go, 1) + + sess.clear() + u = sess.query(User).get(u.id) + assert u.name == 'test2' + assert u.counter == 2 + + def test_insert(self): + class User(object): + pass + mapper(User, users_table) + u = User(name='test', counter=select([5])) + sess = Session() + sess.save(u) + sess.flush() + assert u.counter == 5 + + + class PassiveDeletesTest(UnitOfWorkTest): def setUpAll(self): UnitOfWorkTest.setUpAll(self)