]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- refined mapper._save_obj() which was unnecessarily calling
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Apr 2008 16:13:49 +0000 (16:13 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Apr 2008 16:13:49 +0000 (16:13 +0000)
__ne__() on scalar values during flush [ticket:1015]

CHANGES
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/util.py
test/orm/mapper.py
test/orm/unitofwork.py

diff --git a/CHANGES b/CHANGES
index e53c101de5a7cdc96025d3c3b2181c88c9e28ad1..c23442baf140f096793491c78f3f94e9ddc3644f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -36,6 +36,9 @@ CHANGES
     - fixed Class.collection==None for m2m relationships
       [ticket:4213]
       
+    - refined mapper._save_obj() which was unnecessarily calling
+      __ne__() on scalar values during flush [ticket:1015]
+      
 - sql
     - Added COLLATE support via the .collate(<collation>)
       expression operator and collate(<expr>, <collation>) sql
index 21e0101d224c4d40e36b455d92cc4bda0df1d39f..d7c7cebaa8df5b7393174a09fd6597c6232ca569 100644 (file)
@@ -1574,7 +1574,10 @@ class ResultProxy(object):
         See ExecutionContext for details.
         """
         return self.context.postfetch_cols
-
+    
+    def prefetch_cols(self):
+        return self.context.prefetch_cols
+        
     def supports_sane_rowcount(self):
         """Return ``supports_sane_rowcount`` from the dialect.
 
index 65867d4138f5d139a4c603a6a866e119b8cf7d0b..3c1721f9d9e39ea396127e7e63a6c2ddaa6f2c1e 100644 (file)
@@ -395,3 +395,4 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self._last_updated_params = compiled_parameters
 
             self.postfetch_cols = self.compiled.postfetch
+            self.prefetch_cols = self.compiled.prefetch
\ No newline at end of file
index 8fd26acf19fa1b5a3596c86b282f673d795e6c3a..b1d749d6f8580f10e2fc513989788bfe63c1807d 100644 (file)
@@ -1133,7 +1133,7 @@ class Mapper(object):
                 for rec in update:
                     (state, params, mapper, connection, value_params) = rec
                     c = connection.execute(statement.values(value_params), params)
-                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
+                    mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
 
                     # testlib.pragma exempt:__hash__
                     updated_objects.add((state, connection))
@@ -1157,14 +1157,14 @@ class Mapper(object):
                         for i, col in enumerate(mapper._pks_by_table[table]):
                             if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i:
                                 mapper._set_state_attr_by_column(state, col, primary_key[i])
-                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
+                    mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
 
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
                     # per table
                     for m in util.reversed(list(mapper.iterate_to_root())):
                         if m.__inherits_equated_pairs:
-                            m._synchronize_inherited(state)
+                            m.__synchronize_inherited(state)
 
                     # testlib.pragma exempt:__hash__
                     inserted_objects.add((state, connection))
@@ -1180,26 +1180,32 @@ class Mapper(object):
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
-    def _synchronize_inherited(self, state):
+    def __synchronize_inherited(self, state):
         sync.populate(state, self, state, self, self.__inherits_equated_pairs)
 
-    def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
+    def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated
         on the database side, set up a group-based "deferred" loader
         which will populate those attributes in one query when next accessed.
         """
 
-        postfetch_cols = util.Set(resultproxy.postfetch_cols()).union(util.Set(value_params.keys()))
-        deferred_props = []
+        postfetch_cols = resultproxy.postfetch_cols()
+        generated_cols = list(resultproxy.prefetch_cols())
 
-        for c in self._cols_by_table[table]:
-            if c in postfetch_cols and (not c.key in params or c in value_params):
-                prop = self._columntoproperty[c]
-                deferred_props.append(prop.key)
-            elif not c.primary_key and c.key in params and self._get_state_attr_by_column(state, c) != params[c.key]:
+        if self.polymorphic_on:
+            po = table.corresponding_column(self.polymorphic_on)
+            if po:
+                generated_cols.append(po)
+        if self.version_id_col:
+            generated_cols.append(self.version_id_col)
+
+        for c in generated_cols:
+            if c.key in params:
                 self._set_state_attr_by_column(state, c, params[c.key])
 
+        deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
+
         if deferred_props:
             if self.eager_defaults:
                 _instance_key = self._identity_key_from_state(state)
index bc02b879e1722f979d07e342cac3b14cc978ecab..e88c4b3b9b176c80d91a8405f969848e804d4d66 100644 (file)
@@ -428,6 +428,7 @@ class SimpleProperty(object):
         else:
             return getattr(obj, self.key)
 
+
 class NotImplProperty(object):
   """a property that raises ``NotImplementedError``."""
 
index 2024cbf6d7cb6b3b68933cbc19cc8590e83637d3..8d69085ae01dbaae56ee41298bd25ffe7e808b7f 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import exceptions, sql
 from sqlalchemy.orm import *
 from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
 from testlib import *
+from testlib import fixtures
 from testlib.tables import *
 import testlib.tables as tables
 
@@ -1482,13 +1483,12 @@ class MapperExtensionTest(TestBase):
             'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
             )
 
-class RequirementsTest(TestBase, AssertsExecutionResults):
+class RequirementsTest(ORMTest):
     """Tests the contract for user classes."""
 
-    def setUpAll(self):
-        global metadata, t1, t2, t3, t4, t5, t6
+    def define_tables(self, metadata):
+        global t1, t2, t3, t4, t5, t6
 
-        metadata = MetaData(testing.db)
         t1 = Table('ht1', metadata,
                    Column('id', Integer, primary_key=True),
                    Column('value', String(10)))
@@ -1514,13 +1514,6 @@ class RequirementsTest(TestBase, AssertsExecutionResults):
                    Column('ht1b_id', Integer, ForeignKey('ht1.id'),
                           primary_key=True),
                    Column('value', String(10)))
-        metadata.create_all()
-
-    def setUp(self):
-        clear_mappers()
-
-    def tearDownAll(self):
-        metadata.drop_all()
 
     def test_baseclass(self):
         class OldStyle:
@@ -1591,6 +1584,7 @@ class RequirementsTest(TestBase, AssertsExecutionResults):
                     return self.value == other.value
                 return False
 
+                
         mapper(H1, t1, properties={
             'h2s': relation(H2, backref='h1'),
             'h3s': relation(H3, secondary=t4, backref='h1s'),
@@ -1652,6 +1646,37 @@ class RequirementsTest(TestBase, AssertsExecutionResults):
                                   eagerload_all('h3s.h1s')).all()
         self.assertEqual(len(h1s), 5)
 
+class NoEqFoo(object):
+    def __init__(self, data):
+        self.data = data
+    def __eq__(self, other):
+        raise NotImplementedError()
+    def __ne__(self, other):
+        raise NotImplementedError()
+
+class ScalarRequirementsTest(ORMTest):
+    def define_tables(self, metadata):
+        import pickle
+        global t1
+        t1 = Table('t1', metadata, Column('id', Integer, primary_key=True),
+            Column('data', PickleType(pickler=pickle))  # dont use cPickle due to import weirdness
+        )
+        
+    def test_correct_comparison(self):
+                
+        class H1(fixtures.Base):
+            pass
+            
+        mapper(H1, t1)
+        
+        h1 = H1(data=NoEqFoo('12345'))
+        s = create_session()
+        s.save(h1)
+        s.flush()
+        s.clear()
+        h1 = s.get(H1, h1.id)
+        assert h1.data.data == '12345'
+        
 
 if __name__ == "__main__":
     testenv.main()
index 3867290c8e651a0a4f98671350b9418cacc02465..cd2a3005ea4aa135feb3dfc0d9d23e7479491fde 100644 (file)
@@ -524,8 +524,9 @@ class ClauseAttributesTest(ORMTest):
         assert u.counter == 1
         u.counter = User.counter + 1
         sess.flush()
+
         def go():
-            assert u.counter == 2
+            assert (u.counter == 2) is True  # ensure its not a ClauseElement
         self.assert_sql_count(testing.db, go, 1)
 
     def test_multi_update(self):
@@ -542,7 +543,7 @@ class ClauseAttributesTest(ORMTest):
         sess.flush()
         def go():
             assert u.name == 'test2'
-            assert u.counter == 2
+            assert (u.counter == 2) is True
         self.assert_sql_count(testing.db, go, 1)
 
         sess.clear()
@@ -559,7 +560,7 @@ class ClauseAttributesTest(ORMTest):
         sess = Session()
         sess.save(u)
         sess.flush()
-        assert u.counter == 5
+        assert (u.counter == 5) is True
 
 
 class PassiveDeletesTest(ORMTest):