From: Mike Bayer Date: Sat, 26 Apr 2008 16:13:49 +0000 (+0000) Subject: - refined mapper._save_obj() which was unnecessarily calling X-Git-Tag: rel_0_5beta1~173 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b089e8615b9cec8b7cb4741b1fad8c30afcfc848;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - refined mapper._save_obj() which was unnecessarily calling __ne__() on scalar values during flush [ticket:1015] --- diff --git a/CHANGES b/CHANGES index e53c101de5..c23442baf1 100644 --- a/CHANGES +++ b/CHANGES @@ -36,6 +36,9 @@ CHANGES - fixed Class.collection==None for m2m relationships [ticket:4213] + - refined mapper._save_obj() which was unnecessarily calling + __ne__() on scalar values during flush [ticket:1015] + - sql - Added COLLATE support via the .collate() expression operator and collate(, ) sql diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 21e0101d22..d7c7cebaa8 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1574,7 +1574,10 @@ class ResultProxy(object): See ExecutionContext for details. """ return self.context.postfetch_cols - + + def prefetch_cols(self): + return self.context.prefetch_cols + def supports_sane_rowcount(self): """Return ``supports_sane_rowcount`` from the dialect. diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 65867d4138..3c1721f9d9 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -395,3 +395,4 @@ class DefaultExecutionContext(base.ExecutionContext): self._last_updated_params = compiled_parameters self.postfetch_cols = self.compiled.postfetch + self.prefetch_cols = self.compiled.prefetch \ No newline at end of file diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8fd26acf19..b1d749d6f8 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1133,7 +1133,7 @@ class Mapper(object): for rec in update: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) - mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) + mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) # testlib.pragma exempt:__hash__ updated_objects.add((state, connection)) @@ -1157,14 +1157,14 @@ class Mapper(object): for i, col in enumerate(mapper._pks_by_table[table]): if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: mapper._set_state_attr_by_column(state, col, primary_key[i]) - mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) + mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules # per table for m in util.reversed(list(mapper.iterate_to_root())): if m.__inherits_equated_pairs: - m._synchronize_inherited(state) + m.__synchronize_inherited(state) # testlib.pragma exempt:__hash__ inserted_objects.add((state, connection)) @@ -1180,26 +1180,32 @@ class Mapper(object): if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, state.obj()) - def _synchronize_inherited(self, state): + def __synchronize_inherited(self, state): sync.populate(state, self, state, self, self.__inherits_equated_pairs) - def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): + def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated on the database side, set up a group-based "deferred" loader which will populate those attributes in one query when next accessed. """ - postfetch_cols = util.Set(resultproxy.postfetch_cols()).union(util.Set(value_params.keys())) - deferred_props = [] + postfetch_cols = resultproxy.postfetch_cols() + generated_cols = list(resultproxy.prefetch_cols()) - for c in self._cols_by_table[table]: - if c in postfetch_cols and (not c.key in params or c in value_params): - prop = self._columntoproperty[c] - deferred_props.append(prop.key) - elif not c.primary_key and c.key in params and self._get_state_attr_by_column(state, c) != params[c.key]: + if self.polymorphic_on: + po = table.corresponding_column(self.polymorphic_on) + if po: + generated_cols.append(po) + if self.version_id_col: + generated_cols.append(self.version_id_col) + + for c in generated_cols: + if c.key in params: self._set_state_attr_by_column(state, c, params[c.key]) + deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] + if deferred_props: if self.eager_defaults: _instance_key = self._identity_key_from_state(state) diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index bc02b879e1..e88c4b3b9b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -428,6 +428,7 @@ class SimpleProperty(object): else: return getattr(obj, self.key) + class NotImplProperty(object): """a property that raises ``NotImplementedError``.""" diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 2024cbf6d7..8d69085ae0 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -6,6 +6,7 @@ from sqlalchemy import exceptions, sql from sqlalchemy.orm import * from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt from testlib import * +from testlib import fixtures from testlib.tables import * import testlib.tables as tables @@ -1482,13 +1483,12 @@ class MapperExtensionTest(TestBase): 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'] ) -class RequirementsTest(TestBase, AssertsExecutionResults): +class RequirementsTest(ORMTest): """Tests the contract for user classes.""" - def setUpAll(self): - global metadata, t1, t2, t3, t4, t5, t6 + def define_tables(self, metadata): + global t1, t2, t3, t4, t5, t6 - metadata = MetaData(testing.db) t1 = Table('ht1', metadata, Column('id', Integer, primary_key=True), Column('value', String(10))) @@ -1514,13 +1514,6 @@ class RequirementsTest(TestBase, AssertsExecutionResults): Column('ht1b_id', Integer, ForeignKey('ht1.id'), primary_key=True), Column('value', String(10))) - metadata.create_all() - - def setUp(self): - clear_mappers() - - def tearDownAll(self): - metadata.drop_all() def test_baseclass(self): class OldStyle: @@ -1591,6 +1584,7 @@ class RequirementsTest(TestBase, AssertsExecutionResults): return self.value == other.value return False + mapper(H1, t1, properties={ 'h2s': relation(H2, backref='h1'), 'h3s': relation(H3, secondary=t4, backref='h1s'), @@ -1652,6 +1646,37 @@ class RequirementsTest(TestBase, AssertsExecutionResults): eagerload_all('h3s.h1s')).all() self.assertEqual(len(h1s), 5) +class NoEqFoo(object): + def __init__(self, data): + self.data = data + def __eq__(self, other): + raise NotImplementedError() + def __ne__(self, other): + raise NotImplementedError() + +class ScalarRequirementsTest(ORMTest): + def define_tables(self, metadata): + import pickle + global t1 + t1 = Table('t1', metadata, Column('id', Integer, primary_key=True), + Column('data', PickleType(pickler=pickle)) # dont use cPickle due to import weirdness + ) + + def test_correct_comparison(self): + + class H1(fixtures.Base): + pass + + mapper(H1, t1) + + h1 = H1(data=NoEqFoo('12345')) + s = create_session() + s.save(h1) + s.flush() + s.clear() + h1 = s.get(H1, h1.id) + assert h1.data.data == '12345' + if __name__ == "__main__": testenv.main() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 3867290c8e..cd2a3005ea 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -524,8 +524,9 @@ class ClauseAttributesTest(ORMTest): assert u.counter == 1 u.counter = User.counter + 1 sess.flush() + def go(): - assert u.counter == 2 + assert (u.counter == 2) is True # ensure its not a ClauseElement self.assert_sql_count(testing.db, go, 1) def test_multi_update(self): @@ -542,7 +543,7 @@ class ClauseAttributesTest(ORMTest): sess.flush() def go(): assert u.name == 'test2' - assert u.counter == 2 + assert (u.counter == 2) is True self.assert_sql_count(testing.db, go, 1) sess.clear() @@ -559,7 +560,7 @@ class ClauseAttributesTest(ORMTest): sess = Session() sess.save(u) sess.flush() - assert u.counter == 5 + assert (u.counter == 5) is True class PassiveDeletesTest(ORMTest):