]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- callcounts
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Dec 2010 23:27:23 +0000 (18:27 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 10 Dec 2010 23:27:23 +0000 (18:27 -0500)
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/state.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/sql/expression.py
test/orm/test_attributes.py
test/orm/test_mapper.py

index 482f2be50638dd4bccee895852ab615f9dd5d64c..6b235e5277f3e195edb462891cfbadc000873375 100644 (file)
@@ -408,7 +408,7 @@ class ScalarAttributeImpl(AttributeImpl):
         del dict_[self.key]
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
-        return History.from_attribute(
+        return History.from_scalar_attribute(
             self, state, dict_.get(self.key, NO_VALUE))
 
     def set(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
@@ -470,8 +470,7 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
         else:
             v = dict_.get(self.key, NO_VALUE)
             
-        return History.from_attribute(
-            self, state, v)
+        return History.from_scalar_attribute(self, state, v)
 
     def check_mutable_modified(self, state, dict_):
         a, u, d = self.get_history(state, dict_)
@@ -509,17 +508,14 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def __init__(self, class_, key, callable_, dispatch,
                     trackparent=False, extension=None, copy_function=None,
-                    compare_function=None, **kwargs):
+                    **kwargs):
         super(ScalarObjectAttributeImpl, self).__init__(
                                             class_, 
                                             key,
                                             callable_, dispatch, 
                                             trackparent=trackparent, 
                                             extension=extension,
-                                            compare_function=compare_function, 
                                             **kwargs)
-        if compare_function is None:
-            self.is_equal = mapperutil.identity_equal
 
     def delete(self, state, dict_):
         old = self.get(state, dict_)
@@ -528,13 +524,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def get_history(self, state, dict_, passive=PASSIVE_OFF):
         if self.key in dict_:
-            return History.from_attribute(self, state, dict_[self.key])
+            return History.from_object_attribute(self, state, dict_[self.key])
         else:
             current = self.get(state, dict_, passive=passive)
             if current is PASSIVE_NO_RESULT:
                 return HISTORY_BLANK
             else:
-                return History.from_attribute(self, state, current)
+                return History.from_object_attribute(self, state, current)
 
     def get_all_pending(self, state, dict_):
         if self.key in dict_:
@@ -637,7 +633,7 @@ class CollectionAttributeImpl(AttributeImpl):
         if current is PASSIVE_NO_RESULT:
             return HISTORY_BLANK
         else:
-            return History.from_attribute(self, state, current)
+            return History.from_collection(self, state, current)
 
     def get_all_pending(self, state, dict_):
         # this is basically an inline 
@@ -974,48 +970,76 @@ class History(tuple):
              and instance_state(c) or None
              for c in self.deleted],
             )
-        
+
     @classmethod
-    def from_attribute(cls, attribute, state, current):
+    def from_scalar_attribute(cls, attribute, state, current):
         original = state.committed_state.get(attribute.key, NEVER_SET)
+        if current is NO_VALUE:
+            if (original is not None and
+                original is not NEVER_SET and
+                original is not NO_VALUE):
+                deleted = [original]
+            else:
+                deleted = ()
+            return cls((), (), deleted)
+        elif original is NO_VALUE:
+            return cls([current], (), ())
+        elif (original is NEVER_SET or
+              attribute.is_equal(current, original) is True):
+            # dont let ClauseElement expressions here trip things up
+            return cls((), [current], ())
+        else:
+            if original is not None:
+                deleted = [original]
+            else:
+                deleted = ()
+            return cls([current], (), deleted)
 
-        if hasattr(attribute, 'get_collection'):
-            current = attribute.get_collection(state, state.dict, current)
-            if original is NO_VALUE:
-                return cls(list(current), (), ())
-            elif original is NEVER_SET:
-                return cls((), list(current), ())
+    @classmethod
+    def from_object_attribute(cls, attribute, state, current):
+        original = state.committed_state.get(attribute.key, NEVER_SET)
+        
+        if current is NO_VALUE:
+            if (original is not None and
+                original is not NEVER_SET and
+                original is not NO_VALUE):
+                deleted = [original]
             else:
-                current_set = util.IdentitySet(current)
-                original_set = util.IdentitySet(original)
-
-                # ensure duplicates are maintained
-                return cls(
-                    [x for x in current if x not in original_set],
-                    [x for x in current if x in original_set],
-                    [x for x in original if x not in current_set]
-                )
+                deleted = ()
+            return cls((), (), deleted)
+        elif original is NO_VALUE:
+            return cls([current], (), ())
+        elif (original is NEVER_SET or
+                current is original):
+            return cls((), [current], ())
         else:
-            if current is NO_VALUE:
-                if (original is not None and
-                    original is not NEVER_SET and
-                    original is not NO_VALUE):
-                    deleted = [original]
-                else:
-                    deleted = ()
-                return cls((), (), deleted)
-            elif original is NO_VALUE:
-                return cls([current], (), ())
-            elif (original is NEVER_SET or
-                  attribute.is_equal(current, original) is True):
-                # dont let ClauseElement expressions here trip things up
-                return cls((), [current], ())
+            if original is not None:
+                deleted = [original]
             else:
-                if original is not None:
-                    deleted = [original]
-                else:
-                    deleted = ()
-                return cls([current], (), deleted)
+                deleted = ()
+            return cls([current], (), deleted)
+
+    @classmethod
+    def from_collection(cls, attribute, state, current):
+        original = state.committed_state.get(attribute.key, NEVER_SET)
+        current = attribute.get_collection(state, state.dict, current)
+        
+        if original is NO_VALUE:
+            return cls(list(current), (), ())
+        elif original is NEVER_SET:
+            return cls((), list(current), ())
+        else:
+            current_states = [(instance_state(c), c) for c in current]
+            original_states = [(instance_state(c), c) for c in original]
+            
+            current_set = dict(current_states)
+            original_set = dict(original_states)
+            
+            return cls(
+                [o for s, o in current_states if s not in original_set],
+                [o for s, o in current_states if s in original_set],
+                [o for s, o in original_states if s not in current_set]
+            )
 
 HISTORY_BLANK = History(None, None, None)
 
index 90d105dc94bf893e263ebf7699dc7501c0cac539..0694ee6966b6922b87b4437f415c2fa389f82d83 100644 (file)
@@ -82,6 +82,8 @@ class MapperProperty(object):
     
     """
 
+    set_col_value = None
+    
     def setup(self, context, entity, path, adapter, **kwargs):
         """Called by Query for the purposes of constructing a SQL statement.
 
index 05e904a0dee29094afaa7ccb5fa1ba388d263405..a4662770eb6a3a4bd307c6bb92c208e814ecb38c 100644 (file)
@@ -1246,9 +1246,13 @@ class Mapper(object):
                         self.primary_key_from_instance(instance))
 
     def _identity_key_from_state(self, state):
-        return self.identity_key_from_primary_key(
-                        self._primary_key_from_state(state))
-
+        dict_ = state.dict
+        
+        return self._identity_class, tuple([
+            self._get_state_attr_by_column(state, dict_, col)
+            for col in self.primary_key
+        ])    
+        
     def primary_key_from_instance(self, instance):
         """Return the list of primary key values for the given
         instance.
@@ -1270,7 +1274,11 @@ class Mapper(object):
         return value
 
     def _set_state_attr_by_column(self, state, dict_, column, value):
-        return self._columntoproperty[column]._setattr(state, dict_, value, column)
+        prop = self._columntoproperty[column]
+        if prop.set_col_value:
+            prop.set_col_value(state, dict_, value, column)
+        else:
+            state.manager[prop.key].impl.set(state, dict_, value, None)
 
     def _get_committed_attr_by_column(self, obj, column):
         state = attributes.instance_state(obj)
@@ -1833,15 +1841,17 @@ class Mapper(object):
 
                     if primary_key is not None:
                         # set primary key attributes
-                        for pk, col in zip(primary_key, mapper._pks_by_table[table]):
+                        for pk, col in zip(primary_key, 
+                                        mapper._pks_by_table[table]):
                             # TODO: make sure this inlined code is OK
                             # with composites
                             prop = mapper._columntoproperty[col]
                             if state_dict.get(prop.key) is None:
                                 # TODO: would rather say:
-                                # state_dict[prop.key] = pk
-                                # here, one test fails
-                                prop._setattr(state, state_dict, pk, col)
+                                #state_dict[prop.key] = pk
+                                mapper._set_state_attr_by_column(state, 
+                                                                state_dict, 
+                                                                col, pk)
                                 
                     mapper._postfetch(uowtransaction, table, 
                                         state, state_dict, c,
index 776ecaf99548e6f024e5bc4636e23fd7006066d3..d628d87dcec753a191aebe978c4b90891001cd6b 100644 (file)
@@ -123,9 +123,6 @@ class ColumnProperty(StrategizedProperty):
         return state.get_impl(self.key).\
                     get_committed_value(state, dict_, passive=passive)
 
-    def _setattr(self, state, dict_, value, column):
-        state.manager[self.key].impl.set(state, dict_, value, None)
-
     def merge(self, session, source_state, source_dict, dest_state, 
                                 dest_dict, load, _recursive):
         if self.key in source_dict:
@@ -194,9 +191,8 @@ class CompositeProperty(ColumnProperty):
         obj = state.get_impl(self.key).\
                         get_committed_value(state, dict_, passive=passive)
         return self.get_col_value(column, obj)
-
-    def _setattr(self, state, dict_, value, column):
-
+    
+    def set_col_value(self, state, dict_, value, column):
         obj = state.get_impl(self.key).get(state, dict_)
         if obj is None:
             obj = self.composite_class(*[None for c in self.columns])
index 3984c0e0c467b23825e88c76058797f44a3ffbf0..278f86749ba0ce57120e4d9a6dbb82d1346f4486 100644 (file)
@@ -115,7 +115,7 @@ class InstanceState(object):
             raise
 
     def get_history(self, key, **kwargs):
-        return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
+        return self.manager[key].impl.get_history(self, self.dict, **kwargs)
 
     def get_impl(self, key):
         return self.manager[key].impl
index 05298767de8d96c94720afa8dd19b9ca504bb50c..6b6aeff9b445e17d2a7c998e1821deb702099cdc 100644 (file)
@@ -28,8 +28,9 @@ def populate(source, source_mapper, dest, dest_mapper,
         # how often this logic is invoked for memory/performance
         # reasons, since we only need this info for a primary key
         # destination.
-        if l.primary_key and r.primary_key and \
-                    r.references(l) and flag_cascaded_pks:
+        if flag_cascaded_pks and l.primary_key and \
+                    r.primary_key and \
+                    r.references(l):
             uowcommit.attributes[("pk_cascaded", dest, r)] = True
 
 def clear(dest, dest_mapper, synchronize_pairs):
index 7ee633f3e56eb8763cdc402a3c030894beac72fd..875ce634ba855397eb7005023af67a82abc9b0eb 100644 (file)
@@ -146,27 +146,35 @@ class UOWTransaction(object):
         
     def get_attribute_history(self, state, key, passive=True):
         """facade to attributes.get_state_history(), including caching of results."""
-        
+
         hashkey = ("history", state, key)
 
         # cache the objects, not the states; the strong reference here
         # prevents newly loaded objects from being dereferenced during the
         # flush process
+        
         if hashkey in self.attributes:
-            (history, cached_passive) = self.attributes[hashkey]
-            # if the cached lookup was "passive" and now we want non-passive, do a non-passive
-            # lookup and re-cache
+            history, state_history, cached_passive = self.attributes[hashkey]
+            # if the cached lookup was "passive" and now 
+            # we want non-passive, do a non-passive lookup and re-cache
             if cached_passive and not passive:
-                history = state.get_history(key, passive=False)
-                self.attributes[hashkey] = (history, passive)
+                impl = state.manager[key].impl
+                history = impl.get_history(state, state.dict, passive=False)
+                if history and impl.uses_objects:
+                    state_history = history.as_state()
+                else:
+                    state_history = history
+                self.attributes[hashkey] = (history, state_history, passive)
         else:
-            history = state.get_history(key, passive=passive)
-            self.attributes[hashkey] = (history, passive)
-
-        if not history or not state.get_impl(key).uses_objects:
-            return history
-        else:
-            return history.as_state()
+            impl = state.manager[key].impl
+            history = impl.get_history(state, state.dict, passive=passive)
+            if history and impl.uses_objects:
+                state_history = history.as_state()
+            else:
+                state_history = history
+            self.attributes[hashkey] = (history, state_history, passive)
+        
+        return state_history
     
     def has_dep(self, processor):
         return (processor, True) in self.presort_actions
index 27fa36768f3cd58678001a8f9bd8891c560b36d4..bf055f0b2f0e50c54fd194a6f7f60407b450cf00 100644 (file)
@@ -2013,11 +2013,12 @@ class ColumnCollection(util.OrderedProperties):
 
     def __init__(self, *cols):
         super(ColumnCollection, self).__init__()
-        self.update((c.key, c) for c in cols)
-
+        self._data.update((c.key, c) for c in cols)
+        self.__dict__['_all_cols'] = util.column_set(self)
+        
     def __str__(self):
         return repr([str(c) for c in self])
-
+    
     def replace(self, column):
         """add the given column to this collection, removing unaliased
            versions of this column  as well as existing columns with the
@@ -2037,8 +2038,12 @@ class ColumnCollection(util.OrderedProperties):
         if column.name in self and column.key != column.name:
             other = self[column.name]
             if other.name == other.key:
-                del self[other.name]
-        util.OrderedProperties.__setitem__(self, column.key, column)
+                del self._data[other.name]
+                self._all_cols.remove(other)
+        if column.key in self._data:
+            self._all_cols.remove(self._data[column.key])
+        self._all_cols.add(column)
+        self._data[column.key] = column
 
     def add(self, column):
         """Add a column to this collection.
@@ -2048,7 +2053,13 @@ class ColumnCollection(util.OrderedProperties):
 
         """
         self[column.key] = column
+    
+    def __delitem__(self, key):
+        raise NotImplementedError()
 
+    def __setattr__(self, key, object):
+        raise NotImplementedError()
+        
     def __setitem__(self, key, value):
         if key in self:
 
@@ -2062,14 +2073,25 @@ class ColumnCollection(util.OrderedProperties):
                           'another column with the same key.  Consider '
                           'use_labels for select() statements.' % (key,
                           getattr(existing, 'table', None)))
-        util.OrderedProperties.__setitem__(self, key, value)
+            self._all_cols.remove(existing)
+        self._all_cols.add(value)
+        self._data[key] = value
 
+    def clear(self):
+        self._data.clear()
+        self._all_cols.clear()
+        
     def remove(self, column):
-        del self[column.key]
+        del self._data[column.key]
+        self._all_cols.remove(column)
 
+    def update(self, value):
+        self._data.update(value)
+        self._all_cols.clear()
+        self._all_cols.update(self._data.values())
+        
     def extend(self, iter):
-        for c in iter:
-            self.add(c)
+        self.update((c.key, c) for c in iter)
 
     __hash__ = None
     
@@ -2086,20 +2108,21 @@ class ColumnCollection(util.OrderedProperties):
             raise exc.ArgumentError("__contains__ requires a string argument")
         return util.OrderedProperties.__contains__(self, other)
 
-    def contains_column(self, col):
-
-        # have to use a Set here, because it will compare the identity
-        # of the column, not just using "==" for comparison which will
-        # always return a "True" value (i.e. a BinaryClause...)
+    def __setstate__(self, state):
+        self.__dict__['_data'] = state['_data']
+        self.__dict__['_all_cols'] = util.column_set(self._data.values())
 
-        return col in util.column_set(self)
+    def contains_column(self, col):
+        # this has to be done via set() membership
+        return col in self._all_cols
     
     def as_immutable(self):
-        return ImmutableColumnCollection(self._data)
+        return ImmutableColumnCollection(self._data, self._all_cols)
         
 class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
-    def __init__(self, data):
+    def __init__(self, data, colset):
         util.ImmutableProperties.__init__(self, data)
+        self.__dict__['_all_cols'] = colset
     
     extend = remove = util.ImmutableProperties._immutable
 
index c0481f96b49fb2d5a487019a039dce104903d098..b543e79a124173d2a28dbeda5aa167483906df45 100644 (file)
@@ -1186,6 +1186,7 @@ class HistoryTest(_base.ORMTest):
                 assert False
 
         instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
         attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True)
 
         hi = Bar(name='hi')
@@ -1238,6 +1239,7 @@ class HistoryTest(_base.ORMTest):
         from sqlalchemy.orm.collections import attribute_mapped_collection
 
         instrumentation.register_class(Foo)
+        instrumentation.register_class(Bar)
         attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True, typecallable=attribute_mapped_collection('name'))
 
         hi = Bar(name='hi')
@@ -1266,6 +1268,7 @@ class HistoryTest(_base.ORMTest):
         instrumentation.register_class(Foo)
         attributes.register_attribute(Foo, 'someattr', uselist=True, useobject=True)
         attributes.register_attribute(Foo, 'id', uselist=False, useobject=False)
+        instrumentation.register_class(Bar)
 
         hi = Bar(name='hi')
         there = Bar(name='there')
index 10c3b3abe22efb13f0e4e8e6e866b59c4d85544f..c94ef9b3fd0026862b9b3258fef7413e04d61095 100644 (file)
@@ -2280,7 +2280,8 @@ class CompositeTypesTest(_base.MappedTest):
                 return (self.id, self.version)
             __hash__ = None
             def __eq__(self, other):
-                return other.id == self.id and other.version == self.version
+                return isinstance(other, Version) and other.id == self.id and \
+                                other.version == self.version
             def __ne__(self, other):
                 return not self.__eq__(other)