]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- clarified LoaderStrategy implementations, centralized deferred column loading
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 18:57:02 +0000 (18:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 Jul 2007 18:57:02 +0000 (18:57 +0000)
into DeferredColumnLoader (i.e. deferred polymorphic loader)
- added generic deferred_load(instance, props) method, will set up "deferred" or "lazy"
loads across a set of properties.
- mapper post-fetch now uses all deferreds, no more post-selects inside a flush() [ticket:652]

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/strategies.py
test/orm/unitofwork.py

index fc4433a47cae8900c780ed616691ed9e6c8f0c3f..ca602b58cd8541874a6c380790b3ccac66896592 100644 (file)
@@ -300,6 +300,11 @@ class ExecutionContext(object):
             (i.e. dict or list of dicts for non positional,
             list or list of lists/tuples for positional).
             
+        isinsert
+          True if the statement is an INSERT
+            
+        isupdate
+          True if the statement is an UPDATE
     
     The Dialect should provide an ExecutionContext via the
     create_execution_context() method.  The `pre_exec` and `post_exec`
@@ -388,6 +393,12 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
+    def postfetch_cols(self):
+        """Return a list of Column objects for which a 'passive' server-side default
+        value was fired off."""
+
+        raise NotImplementedError()
+
 class Compiled(object):
     """Represent a compiled SQL expression.
 
@@ -1215,6 +1226,7 @@ class ResultProxy(object):
 
         return self.context.lastrow_has_defaults()
 
+        
     def supports_sane_rowcount(self):
         """Return ``supports_sane_rowcount()`` from the underlying ExecutionContext.
 
index 962e2ab606da67b5dad817381627e696830e50c3..a2e159639dcd9c4cd96c54960f9639427d0c4a5e 100644 (file)
@@ -6,7 +6,7 @@
 
 """Provide default implementations of per-dialect sqlalchemy.engine classes"""
 
-from sqlalchemy import schema, exceptions, sql, types
+from sqlalchemy import schema, exceptions, sql, types, util
 import sys, re
 from sqlalchemy.engine import base
 
@@ -147,6 +147,7 @@ class DefaultExecutionContext(base.ExecutionContext):
         self.dialect = dialect
         self.connection = connection
         self.compiled = compiled
+        self._postfetch_cols = util.Set()
         
         if compiled is not None:
             self.typemap = compiled.typemap
@@ -173,6 +174,8 @@ class DefaultExecutionContext(base.ExecutionContext):
         self.cursor = self.create_cursor()
         
     engine = property(lambda s:s.connection.engine)
+    isinsert = property(lambda s:s.compiled and s.compiled.isinsert)
+    isupdate = property(lambda s:s.compiled and s.compiled.isupdate)
     
     def __encode_param_keys(self, params):
         """apply string encoding to the keys of dictionary-based bind parameters"""
@@ -255,8 +258,11 @@ class DefaultExecutionContext(base.ExecutionContext):
         return self._last_updated_params
 
     def lastrow_has_defaults(self):
-        return self._lastrow_has_defaults
+        return len(self._postfetch_cols)
 
+    def postfetch_cols(self):
+        return self._postfetch_cols
+        
     def set_input_sizes(self):
         """Given a cursor and ClauseParameters, call the appropriate
         style of ``setinputsizes()`` on the cursor, using DBAPI types
@@ -291,13 +297,12 @@ class DefaultExecutionContext(base.ExecutionContext):
         and generate last_inserted_ids() collection."""
 
         # TODO: cleanup
-        if self.compiled.isinsert:
+        if self.isinsert:
             if isinstance(self.compiled_parameters, list):
                 plist = self.compiled_parameters
             else:
                 plist = [self.compiled_parameters]
             drunner = self.dialect.defaultrunner(self)
-            self._lastrow_has_defaults = False
             for param in plist:
                 last_inserted_ids = []
                 # check the "default" status of each column in the table
@@ -305,7 +310,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     # check if it will be populated by a SQL clause - we'll need that
                     # after execution.
                     if c in self.compiled.inline_params:
-                        self._lastrow_has_defaults = True
+                        self._postfetch_cols.add(c)
                         if c.primary_key:
                             last_inserted_ids.append(None)
                     # check if its not present at all.  see if theres a default
@@ -315,7 +320,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                     # the SQL-generated value after execution.
                     elif not c.key in param or param.get_original(c.key) is None:
                         if isinstance(c.default, schema.PassiveDefault):
-                            self._lastrow_has_defaults = True
+                            self._postfetch_cols.add(c)
                         newid = drunner.get_column_default(c)
                         if newid is not None:
                             param.set_value(c.key, newid)
@@ -331,20 +336,19 @@ class DefaultExecutionContext(base.ExecutionContext):
                 # here (hard to do since lastrowid doesnt support it either)
                 self._last_inserted_ids = last_inserted_ids
                 self._last_inserted_params = param
-        elif self.compiled.isupdate:
+        elif self.isupdate:
             if isinstance(self.compiled_parameters, list):
                 plist = self.compiled_parameters
             else:
                 plist = [self.compiled_parameters]
             drunner = self.dialect.defaultrunner(self)
-            self._lastrow_has_defaults = False
             for param in plist:
                 # check the "onupdate" status of each column in the table
                 for c in self.compiled.statement.table.c:
                     # it will be populated by a SQL clause - we'll need that
                     # after execution.
                     if c in self.compiled.inline_params:
-                        pass
+                        self._postfetch_cols.add(c)
                     # its not in the bind parameters, and theres an "onupdate" defined for the column;
                     # execute it and add to bind params
                     elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None):
index 47ff260853d570f4ab9523e2eb54cf8a0435f626..1b081910f5d26ab50ff6662dc5c39dd6927922a0 100644 (file)
@@ -134,7 +134,7 @@ class InstrumentedAttribute(interfaces.PropComparator):
             return None
         return AttributeHistory(self, obj, current, passive=passive)
 
-    def set_callable(self, obj, callable_):
+    def set_callable(self, obj, callable_, clear=False):
         """Set a callable function for this attribute on the given object.
 
         This callable will be executed when the attribute is next
@@ -149,6 +149,9 @@ class InstrumentedAttribute(interfaces.PropComparator):
         ``InstrumentedAttribute` constructor.
         """
 
+        if clear:
+            self.clear(obj)
+            
         if callable_ is None:
             self.initialize(obj)
         else:
@@ -815,14 +818,14 @@ class AttributeManager(object):
         """
         return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute)
 
-    def init_instance_attribute(self, obj, key, callable_=None):
+    def init_instance_attribute(self, obj, key, callable_=None, clear=False):
         """Initialize an attribute on an instance to either a blank
         value, cancelling out any class- or instance-level callables
         that were present, or if a `callable` is supplied set the
         callable to be invoked when the attribute is next accessed.
         """
 
-        getattr(obj.__class__, key).set_callable(obj, callable_)
+        getattr(obj.__class__, key).set_callable(obj, callable_, clear=clear)
 
     def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs):
         """Create a scalar property object, defaulting to
index aeb8a23fa114e3bec963b4fd9f51c4bfd018eec6..655ad4aa694e6b9e0505f5d00374cfa73cd364e1 100644 (file)
@@ -413,7 +413,6 @@ class StrategizedProperty(MapperProperty):
         except KeyError:
             # cache the located strategy per class for faster re-lookup
             strategy = cls(self)
-            strategy.is_default = False
             strategy.init()
             self._all_strategies[cls] = strategy
             return strategy
@@ -631,7 +630,7 @@ class LoaderStrategy(object):
 
     def __init__(self, parent):
         self.parent_property = parent
-        self.is_default = True
+        self.is_class_level = False
 
     def init(self):
         self.parent = self.parent_property.parent
index 92b186012ac3b41bbfd0c7bec219a722cc46d4b2..f63d9fd2bbf07c7ac5c1830a6f5b9ec145212c4e 100644 (file)
@@ -1168,30 +1168,31 @@ class Mapper(object):
                     mapper.extension.after_update(mapper, connection, obj)
 
     def _postfetch(self, connection, table, obj, resultproxy, params):
-        """After an ``INSERT`` or ``UPDATE``, ask the returned result
-        if ``PassiveDefaults`` fired off on the database side which
-        need to be post-fetched, **or** if pre-exec defaults like
-        ``ColumnDefaults`` were fired off and should be populated into
-        the instance. this is only for non-primary key columns.
+        """After an ``INSERT`` or ``UPDATE``, assemble newly generated
+        values on an instance.  For columns which are marked as being generated
+        on the database side, set up a group-based "deferred" loader 
+        which will populate those attributes in one query when next accessed.
         """
 
-        if resultproxy.lastrow_has_defaults():
-            clause = sql.and_()
-            for p in self.pks_by_table[table]:
-                clause.clauses.append(p == self.get_attr_by_column(obj, p))
-            row = connection.execute(table.select(clause), None).fetchone()
-            for c in table.c:
-                if self.get_attr_by_column(obj, c, False) is None:
-                    self.set_attr_by_column(obj, c, row[c])
-        else:
-            for c in table.c:
-                if c.primary_key or not c.key in params:
-                    continue
-                v = self.get_attr_by_column(obj, c, False)
-                if v is NO_ATTRIBUTE:
+        postfetch_cols = resultproxy.context.postfetch_cols()
+        deferred_props = []
+
+        for c in table.c:
+            if c in postfetch_cols and not c.key in params:
+                prop = self._getpropbycolumn(c, raiseerror=False)
+                if prop is None:
                     continue
-                elif v != params.get_original(c.key):
-                    self.set_attr_by_column(obj, c, params.get_original(c.key))
+                deferred_props.append(prop)
+            if c.primary_key or not c.key in params:
+                continue
+            v = self.get_attr_by_column(obj, c, False)
+            if v is NO_ATTRIBUTE:
+                continue
+            elif v != params.get_original(c.key):
+                self.set_attr_by_column(obj, c, params.get_original(c.key))
+        
+        if len(deferred_props):
+            deferred_load(obj, props=deferred_props)
 
     def delete_obj(self, objects, uowtransaction):
         """Issue ``DELETE`` statements for a list of objects.
index 6ce9fd7069a5404b9732726b36565ed81343fcb3..5b0592dd6ea8533bd81afeee2137d9dfa6b2de12 100644 (file)
@@ -80,7 +80,6 @@ class ColumnProperty(StrategizedProperty):
             
 ColumnProperty.logger = logging.class_logger(ColumnProperty)
 
-mapper.ColumnProperty = ColumnProperty
 
 class CompositeProperty(ColumnProperty):
     """subclasses ColumnProperty to provide composite type support."""
@@ -293,6 +292,7 @@ class PropertyLoader(StrategizedProperty):
                 if obj is not None:
                     setattr(dest, self.key, obj)
 
+
     def cascade_iterator(self, type, object, recursive, halt_on=None):
         if not type in self.cascade:
             return
@@ -684,3 +684,29 @@ class BackRef(object):
         """Return an attribute extension to use with this backreference."""
 
         return attributes.GenericBackrefExtension(self.key)
+
+def deferred_load(instance, props):
+    """set multiple instance attributes to 'deferred' or 'lazy' load, for the given set of MapperProperty objects.
+
+    this will remove the current value of the attribute and set a per-instance
+    callable to fire off when the instance is next accessed.
+    
+    for column-based properties, aggregates them into a single list against a single deferred loader
+    so that a single column access loads all columns
+
+    """
+
+    if not len(props):
+        return
+    column_props = [p for p in props if isinstance(p, ColumnProperty)]
+    callable_ = column_props[0]._get_strategy(strategies.DeferredColumnLoader).setup_loader(instance, props=column_props)
+    for p in column_props:
+        sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
+        
+    for p in [p for p in props if isinstance(p, PropertyLoader)]:
+        callable_ = p._get_strategy(strategies.LazyLoader).setup_loader(instance)
+        sessionlib.attribute_manager.init_instance_attribute(instance, p.key, callable_=callable_, clear=True)
+
+mapper.ColumnProperty = ColumnProperty
+mapper.deferred_load = deferred_load
+        
index babd6e4c09860174351f601aa8afdbbff44165d2..501926d499a4513ccc782c1b0d22eb07ae8d9fd1 100644 (file)
@@ -28,6 +28,7 @@ class ColumnLoader(LoaderStrategy):
                 context.statement.append_column(c)
         
     def init_class_attribute(self):
+        self.is_class_level = True
         if self.is_composite:
             self._init_composite_attribute()
         else:
@@ -73,47 +74,38 @@ class ColumnLoader(LoaderStrategy):
             self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
             return (execute, None)
 
+        # our mapped column is not present in the row.  check if we need to initialize a polymorphic
+        # row fetcher used by inheritance.
         (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
         if hosted_mapper is None:
             return (None, None)
         
         if hosted_mapper.polymorphic_fetch == 'deferred':
+            # 'deferred' polymorphic row fetcher, put a callable on the property.
             def execute(instance, row, isnew, **flags):
                 if isnew:
-                    sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables))
+                    sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_inheritance_loader(instance, mapper, needs_tables))
             self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
             return (execute, None)
         else:  
+            # immediate polymorphic row fetcher.  no processing needed for this row.
             self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
             return (None, None)
 
-    def _get_deferred_loader(self, instance, mapper, needs_tables):
-        def load():
-            group = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
-
-            if self._should_log_debug:
-                self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
-
-            session = sessionlib.object_session(instance)
-            if session is None:
-                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
+    def _get_deferred_inheritance_loader(self, instance, mapper, needs_tables):
+        def create_statement():
             cond, param_names = mapper._deferred_inheritance_condition(needs_tables)
             statement = sql.select(needs_tables, cond, use_labels=True)
             params = {}
             for c in param_names:
                 params[c.name] = mapper.get_attr_by_column(instance, c)
+            return (statement, params)
+            
+        strategy = self.parent_property._get_strategy(DeferredColumnLoader)
 
-            result = session.execute(statement, params, mapper=mapper)
-            try:
-                row = result.fetchone()
-                for prop in group:
-                    sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
-                return attributes.ATTR_WAS_SET
-            finally:
-                result.close()
+        props = [p for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
+        return strategy.setup_loader(instance, props=props, create_statement=create_statement)
 
-        return load
 
 ColumnLoader.logger = logging.class_logger(ColumnLoader)
 
@@ -127,7 +119,7 @@ class DeferredColumnLoader(LoaderStrategy):
     def create_row_processor(self, selectcontext, mapper, row):
         if self.group is not None and selectcontext.attributes.get(('undefer', self.group), False):
             return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
-        elif not self.is_default or len(selectcontext.options):
+        elif not self.is_class_level or len(selectcontext.options):
             def execute(instance, row, isnew, **flags):
                 if isnew:
                     if self._should_log_debug:
@@ -151,6 +143,7 @@ class DeferredColumnLoader(LoaderStrategy):
         self._should_log_debug = logging.is_debug_enabled(self.logger)
 
     def init_class_attribute(self):
+        self.is_class_level = True
         self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
         sessionlib.attribute_manager.register_attribute(self.parent.class_, self.key, uselist=False, callable_=self.setup_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
 
@@ -158,23 +151,29 @@ class DeferredColumnLoader(LoaderStrategy):
         if self.group is not None and context.attributes.get(('undefer', self.group), False):
             self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
         
-    def setup_loader(self, instance):
+    def setup_loader(self, instance, props=None, create_statement=None):
         localparent = mapper.object_mapper(instance, raiseerror=False)
         if localparent is None:
             return None
-            
+
+        # adjust for the ColumnProperty associated with the instance
+        # not being our own ColumnProperty.  This can occur when entity_name
+        # mappers are used to map different versions of the same ColumnProperty
+        # to the class.
         prop = localparent.get_property(self.key)
         if prop is not self.parent_property:
             return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-
+            
         def lazyload():
             if not mapper.has_identity(instance):
                 return None
-
-            if self.group is not None:
+            
+            if props is not None:
+                group = props
+            elif self.group is not None:
                 group = [p for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
             else:
-                group = None
+                group = [self.parent_property]
                 
             if self._should_log_debug:
                 self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join([p.key for p in group]) or 'None'))
@@ -182,28 +181,25 @@ class DeferredColumnLoader(LoaderStrategy):
             session = sessionlib.object_session(instance)
             if session is None:
                 raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-                
-            clause = localparent._get_clause
-            ident = instance._instance_key[1]
-            params = {}
-            for i, primary_key in enumerate(localparent.primary_key):
-                params[primary_key._label] = ident[i]
-            if group is not None:
+
+            if create_statement is None:
+                clause = localparent._get_clause
+                ident = instance._instance_key[1]
+                params = {}
+                for i, primary_key in enumerate(localparent.primary_key):
+                    params[primary_key._label] = ident[i]
                 statement = sql.select([p.columns[0] for p in group], clause, from_obj=[localparent.mapped_table], use_labels=True)
             else:
-                statement = sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True)
-                    
-            if group is not None:
-                result = session.execute(statement, params, mapper=localparent)
-                try:
-                    row = result.fetchone()
-                    for prop in group:
-                        sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
-                    return attributes.ATTR_WAS_SET
-                finally:
-                    result.close()
-            else:
-                return session.scalar(sql.select([self.columns[0]], clause, from_obj=[localparent.mapped_table], use_labels=True),params, mapper=localparent)
+                statement, params = create_statement()
+                
+            result = session.execute(statement, params, mapper=localparent)
+            try:
+                row = result.fetchone()
+                for prop in group:
+                    sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]])
+                return attributes.ATTR_WAS_SET
+            finally:
+                result.close()
 
         return lazyload
                 
@@ -245,18 +241,16 @@ class AbstractRelationLoader(LoaderStrategy):
 
 class NoLoader(AbstractRelationLoader):
     def init_class_attribute(self):
+        self.is_class_level = True
         self._register_attribute(self.parent.class_)
 
     def create_row_processor(self, selectcontext, mapper, row):
-        if not self.is_default or len(selectcontext.options):
-            def execute(instance, row, isnew, **flags):
-                if isnew:
-                    if self._should_log_debug:
-                        self.logger.debug("set instance-level no loader on %s" % mapperutil.attribute_str(instance, self.key))
-                    self._init_instance_attribute(instance)
-            return (execute, None)
-        else:
-            return (None, None)
+        def execute(instance, row, isnew, **flags):
+            if isnew:
+                if self._should_log_debug:
+                    self.logger.debug("initializing blank scalar/collection on %s" % mapperutil.attribute_str(instance, self.key))
+                self._init_instance_attribute(instance)
+        return (execute, None)
 
 NoLoader.logger = logging.class_logger(NoLoader)
         
@@ -275,15 +269,21 @@ class LazyLoader(AbstractRelationLoader):
             self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
 
     def init_class_attribute(self):
+        self.is_class_level = True
         self._register_attribute(self.parent.class_, callable_=lambda i: self.setup_loader(i))
 
     def setup_loader(self, instance, options=None):
         if not mapper.has_mapper(instance):
             return None
         else:
+            # adjust for the PropertyLoader associated with the instance
+            # not being our own PropertyLoader.  This can occur when entity_name
+            # mappers are used to map different versions of the same PropertyLoader
+            # to the class.
             prop = mapper.object_mapper(instance).get_property(self.key)
             if prop is not self.parent_property:
                 return prop._get_strategy(LazyLoader).setup_loader(instance)
+
         def lazyload():
             self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
             params = {}
@@ -351,13 +351,13 @@ class LazyLoader(AbstractRelationLoader):
         return lazyload
 
     def create_row_processor(self, selectcontext, mapper, row):
-        if not self.is_default or len(selectcontext.options):
+        if not self.is_class_level or len(selectcontext.options):
             def execute(instance, row, isnew, **flags):
                 if isnew:
                     if self._should_log_debug:
                         self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
                     # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
-                    # which will override the clareset_instance_attributess-level behavior
+                    # which will override the class-level behavior
                     self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options))
             return (execute, None)
         else:
@@ -435,13 +435,16 @@ class EagerLoader(AbstractRelationLoader):
     
     def init(self):
         super(EagerLoader, self).init()
-        if self.is_default:
-            self.parent._eager_loaders.add(self.parent_property)
-
         self.clauses = {}
         self.join_depth = self.parent_property.join_depth
 
     def init_class_attribute(self):
+        # class-level eager strategy; add the PropertyLoader
+        # to the parent's list of "eager loaders"; this tells the Query
+        # that eager loaders will be used in a normal query
+        self.parent._eager_loaders.add(self.parent_property)
+        
+        # initialize a lazy loader on the class level attribute
         self.parent_property._get_strategy(LazyLoader).init_class_attribute()
         
     def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
@@ -455,7 +458,6 @@ class EagerLoader(AbstractRelationLoader):
             path = parentclauses.path + (self.parent.base_mapper(), self.key)
         else:
             path = (self.parent.base_mapper(), self.key)
-
         
         if self.join_depth:
             if len(path) / 2 > self.join_depth:
index ae626db849cce57ebc4e0c060fab29ea5104260d..a8ce797f24a2023328b420365ec5fbfac23ff152 100644 (file)
@@ -527,6 +527,7 @@ class DefaultTest(UnitOfWorkTest):
             hohotype = Integer
             self.hohoval = 9
             self.althohoval = 15
+            
         global default_table
         metadata = MetaData(db)
         default_table = Table('default_test', metadata,
@@ -539,25 +540,42 @@ class DefaultTest(UnitOfWorkTest):
     def tearDownAll(self):
         default_table.drop()
         UnitOfWorkTest.tearDownAll(self)
+        
     def testinsert(self):
         class Hoho(object):pass
         assign_mapper(Hoho, default_table)
+        
         h1 = Hoho(hoho=self.althohoval)
         h2 = Hoho(counter=12)
         h3 = Hoho(hoho=self.althohoval, counter=12)
         h4 = Hoho()
         h5 = Hoho(foober='im the new foober')
         ctx.current.flush()
+        
         self.assert_(h1.hoho==self.althohoval)
         self.assert_(h3.hoho==self.althohoval)
-        self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
-        self.assert_(h3.counter == h2.counter == 12)
-        self.assert_(h1.counter ==  h4.counter==h5.counter==7)
-        self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
-        self.assert_(h5.foober=='im the new foober')
+        
+        def go():
+            # test deferred load of attributes, one select per instance
+            self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
+        self.assert_sql_count(testbase.db, go, 3)
+        
+        def go():
+            self.assert_(h1.counter ==  h4.counter==h5.counter==7)
+        self.assert_sql_count(testbase.db, go, 1)
+        
+        def go():
+            self.assert_(h3.counter == h2.counter == 12)
+            self.assert_(h2.foober == h3.foober == h4.foober == 'im foober')
+            self.assert_(h5.foober=='im the new foober')
+        self.assert_sql_count(testbase.db, go, 0)
+        
         ctx.current.clear()
+        
         l = Query(Hoho).select()
+        
         (h1, h2, h3, h4, h5) = l
+        
         self.assert_(h1.hoho==self.althohoval)
         self.assert_(h3.hoho==self.althohoval)
         self.assert_(h2.hoho==h4.hoho==h5.hoho==self.hohoval)
@@ -570,11 +588,15 @@ class DefaultTest(UnitOfWorkTest):
         # populates the PassiveDefaults explicitly so there is no "post-update"
         class Hoho(object):pass
         assign_mapper(Hoho, default_table)
+        
         h1 = Hoho(hoho="15", counter="15")
+        
         ctx.current.flush()
-        self.assert_(h1.hoho=="15")
-        self.assert_(h1.counter=="15")
-        self.assert_(h1.foober=="im foober")
+        def go():
+            self.assert_(h1.hoho=="15")
+            self.assert_(h1.counter=="15")
+            self.assert_(h1.foober=="im foober")
+        self.assert_sql_count(testbase.db, go, 0)
         
     def testupdate(self):
         class Hoho(object):pass
@@ -1320,6 +1342,7 @@ class ManyToManyTest(UnitOfWorkTest):
         ctx.current.clear()
         item = ctx.current.query(Item).get(item.item_id)
         print [k1, k2]
+        print item.keywords
         assert item.keywords == [k1, k2]
         
     def testassociation(self):