]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added a mapper() flag "eager_defaults"; when set to
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Jan 2008 22:32:51 +0000 (22:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Jan 2008 22:32:51 +0000 (22:32 +0000)
True, defaults that are generated during an INSERT
or UPDATE operation are post-fetched immediately,
instead of being deferred until later.  This mimics
the old 0.3 behavior.

CHANGES
lib/sqlalchemy/orm/mapper.py
test/sql/defaults.py

diff --git a/CHANGES b/CHANGES
index c0a63a546d2365970e81b97b8186aa6ebf76142e..32970990986764b911c178ea61882429e8ea2976 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,12 @@ CHANGES
     - proper error message is raised when trying to 
       access expired instance attributes with no session
       present
+    
+    - added a mapper() flag "eager_defaults"; when set to
+      True, defaults that are generated during an INSERT
+      or UPDATE operation are post-fetched immediately, 
+      instead of being deferred until later.  This mimics
+      the old 0.3 behavior.
       
 - dialects
     - finally added PGMacAddr type to postgres 
index dcfae524f6b6866c9f3ea9615999e44ddc1e52d2..1451f4336c72204e6f13c4e3ac57a695b4c40596 100644 (file)
@@ -72,7 +72,8 @@ class Mapper(object):
                  batch=True,
                  column_prefix=None,
                  include_properties=None,
-                 exclude_properties=None):
+                 exclude_properties=None,
+                 eager_defaults=False):
         """Construct a new mapper.
 
         Mappers are normally constructed via the [sqlalchemy.orm#mapper()] 
@@ -111,6 +112,7 @@ class Mapper(object):
         self.allow_null_pks = allow_null_pks
         self.delete_orphans = []
         self.batch = batch
+        self.eager_defaults = eager_defaults
         self.column_prefix = column_prefix
         self.polymorphic_on = polymorphic_on
         self._eager_loaders = util.Set()
@@ -1086,7 +1088,7 @@ class Mapper(object):
                 for rec in update:
                     (state, params, mapper, connection, value_params) = rec
                     c = connection.execute(statement.values(value_params), params)
-                    mapper._postfetch(connection, table, state, c, c.last_updated_params(), value_params)
+                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params)
 
                     # testlib.pragma exempt:__hash__
                     updated_objects.add((state, connection))
@@ -1110,7 +1112,7 @@ class Mapper(object):
                         for i, col in enumerate(mapper._pks_by_table[table]):
                             if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i:
                                 mapper._set_state_attr_by_column(state, col, primary_key[i])
-                    mapper._postfetch(connection, table, state, c, c.last_inserted_params(), value_params)
+                    mapper._postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params)
 
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
@@ -1133,7 +1135,7 @@ class Mapper(object):
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, state.obj())
     
-    def _postfetch(self, connection, table, state, resultproxy, params, value_params):
+    def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated
         on the database side, set up a group-based "deferred" loader 
@@ -1151,7 +1153,13 @@ class Mapper(object):
                 self._set_state_attr_by_column(state, c, params[c.key])
         
         if deferred_props:
-            _expire_state(state, deferred_props)
+            # TODO: need a unit test for this functionality
+            if self.eager_defaults:
+                _instance_key = self._identity_key_from_state(state)
+                state.dict['_instance_key'] = _instance_key
+                uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props)
+            else:
+                _expire_state(state, deferred_props)
 
     def _delete_obj(self, states, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
@@ -1310,7 +1318,14 @@ class Mapper(object):
             
             if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
                 raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
-            
+        elif refresh_instance:
+            # out of band refresh_instance detected (i.e. its not in the session.identity_map)
+            # honor it anyway.  this can happen if a _get() occurs within save_obj(), such as
+            # when eager_defaults is True.
+            state = refresh_instance
+            instance = state.obj()
+            isnew = state.runid != context.runid
+            currentload = True
         else:
             if self.__should_log_debug:
                 self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
@@ -1526,12 +1541,12 @@ def has_mapper(object):
 object_session = None
 
 def _load_scalar_attributes(instance, attribute_names):
+    mapper = object_mapper(instance)
+
     global object_session
     if not object_session:
         from sqlalchemy.orm.session import object_session
-
     session = object_session(instance)
-    mapper = object_mapper(instance)
     if not session:
         try:
             session = mapper.get_session()
index 98b379995d3e43b78edd57cb26c32c12e24dda8d..a41ef4a1721c568d4dbd73f4ffcef4350c7cd9a3 100644 (file)
@@ -407,8 +407,12 @@ class SequenceTest(PersistTest):
     def testseqnonpk(self):
         """test sequences fire off as defaults on non-pk columns"""
 
-        sometable.insert().execute(name="somename")
-        sometable.insert().execute(name="someother")
+        result = sometable.insert().execute(name="somename")
+        assert 'id' in result.postfetch_cols()
+        
+        result = sometable.insert().execute(name="someother")
+        assert 'id' in result.postfetch_cols()
+
         sometable.insert().execute(
             {'name':'name3'},
             {'name':'name4'}