]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved class-level attributes placed by the attributes package into a _class_state
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Dec 2007 00:31:26 +0000 (00:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Dec 2007 00:31:26 +0000 (00:31 +0000)
variable attached to the class.
- mappers track themselves primarily using the "mappers" collection on _class_state.
ClassKey is gone and mapper lookup uses regular dict keyed to entity_name; removes
a fair degree of WeakKeyDictionary overhead as well as ClassKey overhead.
- mapper_registry renamed to _mapper_registry; is only consulted by the
compile_mappers(), mapper.compile() and clear_mappers() functions/methods.

CHANGES
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/entity.py

diff --git a/CHANGES b/CHANGES
index 6e275254f3e6e586751fc7879c270be8b881160f..a785ac1a97bacc8ff7f80c4a81a43dd2261adaad 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -37,7 +37,8 @@ CHANGES
    - several ORM attributes have been removed or made private:
      mapper.get_attr_by_column(), mapper.set_attr_by_column(), 
      mapper.pks_by_table, mapper.cascade_callable(), 
-     MapperProperty.cascade_callable(), mapper.canload()
+     MapperProperty.cascade_callable(), mapper.canload(),
+     mapper._mapper_registry, attributes.AttributeManager
      
    - fixed endless loop issue when using lazy="dynamic" on both 
      sides of a bi-directional relationship [ticket:872]
index 9e42b1214891f08cab828072855f72391bf056fd..7f5672371bd8f2299ba846ab70138311ec31498f 100644 (file)
@@ -11,7 +11,7 @@ constructors.
 """
 
 from sqlalchemy import util as sautil
-from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry
+from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry
 from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator
 from sqlalchemy.orm.properties import SynonymProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef
 from sqlalchemy.orm import mapper as mapperlib
@@ -21,7 +21,7 @@ from sqlalchemy.orm.util import polymorphic_union, create_row_adapter
 from sqlalchemy.orm.session import Session as _Session
 from sqlalchemy.orm.session import object_session, sessionmaker
 from sqlalchemy.orm.scoping import ScopedSession
-
+from itertools import chain
 
 __all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload',
             'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer',
@@ -567,9 +567,9 @@ def compile_mappers():
     This is equivalent to calling ``compile()`` on any individual mapper.
     """
 
-    if not mapper_registry:
+    if not _mapper_registry:
         return
-    mapper_registry.values()[0].compile()
+    _mapper_registry.values()[0][0].compile()
 
 def clear_mappers():
     """Remove all mappers that have been created thus far.
@@ -579,10 +579,9 @@ def clear_mappers():
     """
     mapperlib._COMPILE_MUTEX.acquire()
     try:
-        for mapper in mapper_registry.values():
+        for mapper in chain(*_mapper_registry.values()):
             mapper.dispose()
-        mapper_registry.clear()
-        mapperlib.ClassKey.dispose(mapperlib.ClassKey)
+        _mapper_registry.clear()
         from sqlalchemy.orm import dependency
         dependency.MapperStub.dispose(dependency.MapperStub)
     finally:
index 7b4d286e8d04e3a34da6d55b3e51847e572d98f6..5e3747e002e5b164e81653658e29b5da83d6355f 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import operator, weakref, threading
+import weakref, threading, operator
 from itertools import chain
 import UserDict
 from sqlalchemy import util
@@ -12,13 +12,15 @@ from sqlalchemy.orm import interfaces, collections
 from sqlalchemy.orm.util import identity_equal
 from sqlalchemy import exceptions
 
-
 PASSIVE_NORESULT = object()
 ATTR_WAS_SET = object()
 NO_VALUE = object()
 
 class InstrumentedAttribute(interfaces.PropComparator):
-    """public-facing instrumented attribute."""
+    """public-facing instrumented attribute, placed in the 
+    class dictionary.
+    
+    """
     
     def __init__(self, impl, comparator=None):
         """Construct an InstrumentedAttribute.
@@ -29,19 +31,19 @@ class InstrumentedAttribute(interfaces.PropComparator):
         self.impl = impl
         self.comparator = comparator
 
-    def __set__(self, obj, value):
-        self.impl.set(obj._state, value, None)
+    def __set__(self, instance, value):
+        self.impl.set(instance._state, value, None)
 
-    def __delete__(self, obj):
-        self.impl.delete(obj._state)
+    def __delete__(self, instance):
+        self.impl.delete(instance._state)
 
-    def __get__(self, obj, owner):
-        if obj is None:
+    def __get__(self, instance, owner):
+        if instance is None:
             return self
-        return self.impl.get(obj._state)
+        return self.impl.get(instance._state)
 
-    def get_history(self, obj, **kwargs):
-        return self.impl.get_history(obj._state, **kwargs)
+    def get_history(self, instance, **kwargs):
+        return self.impl.get_history(instance._state, **kwargs)
         
     def clause_element(self):
         return self.comparator.clause_element()
@@ -64,6 +66,10 @@ class InstrumentedAttribute(interfaces.PropComparator):
     property = property(_property, doc="the MapperProperty object associated with this attribute")
 
 class ProxiedAttribute(InstrumentedAttribute):
+    """a 'proxy' attribute which adds InstrumentedAttribute
+    class-level behavior to any user-defined class property.
+    """
+    
     class ProxyImpl(object):
         def __init__(self, key):
             self.key = key
@@ -76,17 +82,15 @@ class ProxiedAttribute(InstrumentedAttribute):
         self.comparator = comparator
         self.key = key
         self.impl = ProxiedAttribute.ProxyImpl(key)
-    def __get__(self, obj, owner):
-        if obj is None:
-            self.user_prop.__get__(obj, owner)                
+    def __get__(self, instance, owner):
+        if instance is None:
+            self.user_prop.__get__(instance, owner)                
             return self
-        return self.user_prop.__get__(obj, owner)
-    def __set__(self, obj, value):
-        return self.user_prop.__set__(obj, value)
-    def __delete__(self, obj):
-        return self.user_prop.__delete__(obj)
-
-        
+        return self.user_prop.__get__(instance, owner)
+    def __set__(self, instance, value):
+        return self.user_prop.__set__(instance, value)
+    def __delete__(self, instance):
+        return self.user_prop.__delete__(instance)
     
 class AttributeImpl(object):
     """internal implementation for instrumented attributes."""
@@ -131,7 +135,7 @@ class AttributeImpl(object):
         self.trackparent = trackparent
         self.mutable_scalars = mutable_scalars
         if mutable_scalars:
-            class_._sa_has_mutable_scalars = True
+            class_._class_state.has_mutable_scalars = True
         self.copy = None
         if compare_function is None:
             self.is_equal = operator.eq
@@ -276,17 +280,17 @@ class AttributeImpl(object):
         state.modified = True
         if self.trackparent and value is not None:
             self.sethasparent(value._state, True)
-        obj = state.obj()
+        instance = state.obj()
         for ext in self.extensions:
-            ext.append(obj, value, initiator or self)
+            ext.append(instance, value, initiator or self)
 
     def fire_remove_event(self, state, value, initiator):
         state.modified = True
         if self.trackparent and value is not None:
             self.sethasparent(value._state, False)
-        obj = state.obj()
+        instance = state.obj()
         for ext in self.extensions:
-            ext.remove(obj, value, initiator or self)
+            ext.remove(instance, value, initiator or self)
 
     def fire_replace_event(self, state, value, previous, initiator):
         state.modified = True
@@ -295,9 +299,9 @@ class AttributeImpl(object):
                 self.sethasparent(value._state, True)
             if previous is not None:
                 self.sethasparent(previous._state, False)
-        obj = state.obj()
+        instance = state.obj()
         for ext in self.extensions:
-            ext.set(obj, value, previous, initiator or self)
+            ext.set(instance, value, previous, initiator or self)
 
 class ScalarAttributeImpl(AttributeImpl):
     """represents a scalar value-holding InstrumentedAttribute."""
@@ -331,7 +335,7 @@ class ScalarAttributeImpl(AttributeImpl):
             return False
 
     def set(self, state, value, initiator):
-        """Set a value on the given object.
+        """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
         ``set()` operation and is used to control the depth of a circular
@@ -367,7 +371,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         self.fire_remove_event(state, old, self)
 
     def set(self, state, value, initiator):
-        """Set a value on the given object.
+        """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
         ``set()` operation and is used to control the depth of a circular
@@ -542,7 +546,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
     def __init__(self, key):
         self.key = key
 
-    def set(self, obj, child, oldchild, initiator):
+    def set(self, instance, child, oldchild, initiator):
         if oldchild is child:
             return
         if oldchild is not None:
@@ -550,18 +554,25 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
             # present when updating via a backref.
             impl = getattr(oldchild.__class__, self.key).impl
             try:                
-                impl.remove(oldchild._state, obj, initiator)
+                impl.remove(oldchild._state, instance, initiator)
             except (ValueError, KeyError, IndexError):
                 pass
         if child is not None:
-            getattr(child.__class__, self.key).impl.append(child._state, obj, initiator)
+            getattr(child.__class__, self.key).impl.append(child._state, instance, initiator)
 
-    def append(self, obj, child, initiator):
-        getattr(child.__class__, self.key).impl.append(child._state, obj, initiator)
+    def append(self, instance, child, initiator):
+        getattr(child.__class__, self.key).impl.append(child._state, instance, initiator)
 
-    def remove(self, obj, child, initiator):
-        getattr(child.__class__, self.key).impl.remove(child._state, obj, initiator)
+    def remove(self, instance, child, initiator):
+        getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator)
 
+class ClassState(object):
+    """tracks state information at the class level."""
+    def __init__(self):
+        self.mappers = {}
+        self.attrs = {}
+        self.has_mutable_scalars = False
+        
 class InstanceState(object):
     """tracks state information at the instance level."""
 
@@ -583,7 +594,7 @@ class InstanceState(object):
         instance_dict = self.instance_dict
         if instance_dict is None:
             return
-            
+        
         instance_dict = instance_dict()
         if instance_dict is None:
             return
@@ -599,11 +610,15 @@ class InstanceState(object):
             id2 = self.instance_dict
             if id2 is None or id2() is None or self.obj() is not None:
                 return
-                
-            self.__resurrect(instance_dict)
+            
+            try:
+                self.__resurrect(instance_dict)
+            except:
+                # catch GC exceptions
+                pass
         finally:
             instance_dict._mutex.release()
-    
+            
     def _check_resurrect(self, instance_dict):
         instance_dict._mutex.acquire()
         try:
@@ -614,8 +629,8 @@ class InstanceState(object):
     def is_modified(self):
         if self.modified:
             return True
-        elif getattr(self.class_, '_sa_has_mutable_scalars', False):
-            for attr in managed_attributes(self.class_):
+        elif self.class_._class_state.has_mutable_scalars:
+            for attr in _managed_attributes(self.class_):
                 if getattr(attr.impl, 'mutable_scalars', False) and attr.impl.check_mutable_modified(self):
                     return True
             else:
@@ -669,7 +684,7 @@ class InstanceState(object):
         if not hasattr(self, 'expired_attributes'):
             self.expired_attributes = util.Set()
         if attribute_names is None:
-            for attr in managed_attributes(self.class_):
+            for attr in _managed_attributes(self.class_):
                 self.dict.pop(attr.impl.key, None)
                 self.callables[attr.impl.key] = self.__fire_trigger
                 self.expired_attributes.add(attr.impl.key)
@@ -707,7 +722,7 @@ class InstanceState(object):
         
         self.committed_state = {}
         self.modified = False
-        for attr in managed_attributes(self.class_):
+        for attr in _managed_attributes(self.class_):
             attr.impl.commit_to_state(self)
         # remove strong ref
         self._strong_obj = None
@@ -802,9 +817,9 @@ class WeakInstanceDict(UserDict.UserDict):
         
     def itervalues(self):
         for state in self.data.itervalues():
-            obj = state.obj()
-            if obj is not None:
-                yield obj
+            instance = state.obj()
+            if instance is not None:
+                yield instance
 
     def values(self):
         L = []
@@ -841,8 +856,6 @@ class AttributeHistory(object):
     particular instance.
     """
 
-    NO_VALUE = object()
-    
     def __init__(self, attr, state, current, passive=False):
         self.attr = attr
 
@@ -905,26 +918,19 @@ class AttributeHistory(object):
     def deleted_items(self):
         return list(self._deleted_items)
 
-def managed_attributes(class_):
+def _managed_attributes(class_):
     """return all InstrumentedAttributes associated with the given class_ and its superclasses."""
     
-    return chain(*[getattr(cl, '_sa_attrs', []) for cl in class_.__mro__[:-1]])
+    return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')])
 
-def noninherited_managed_attributes(class_):
-    """return all InstrumentedAttributes associated with the given class_, but not its superclasses."""
-
-    return getattr(class_, '_sa_attrs', [])
-
-def is_modified(obj):
-    return obj._state.is_modified()
-
-        
-def get_history(obj, key, **kwargs):
+def is_modified(instance):
+    return instance._state.is_modified()
 
-    return getattr(obj.__class__, key).impl.get_history(obj._state, **kwargs)
+def get_history(instance, key, **kwargs):
+    return getattr(instance.__class__, key).impl.get_history(instance._state, **kwargs)
 
-def get_as_list(obj, key, passive=False):
-    """Return an attribute of the given name from the given object.
+def get_as_list(instance, key, passive=False):
+    """Return an attribute of the given name from the given instance.
 
     If the attribute is a scalar, return it as a single-item list,
     otherwise return a collection based attribute.
@@ -934,8 +940,8 @@ def get_as_list(obj, key, passive=False):
     `passive` flag is False.
     """
 
-    attr = getattr(obj.__class__, key).impl
-    state = obj._state
+    attr = getattr(instance.__class__, key).impl
+    state = instance._state
     x = attr.get(state, passive=passive)
     if x is PASSIVE_NORESULT:
         return []
@@ -946,8 +952,8 @@ def get_as_list(obj, key, passive=False):
     else:
         return [x]
 
-def has_parent(class_, obj, key, optimistic=False):
-    return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic)
+def has_parent(class_, instance, key, optimistic=False):
+    return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic)
 
 def _create_prop(class_, key, uselist, callable_, typecallable, useobject, **kwargs):
     if kwargs.pop('dynamic', False):
@@ -962,12 +968,12 @@ def _create_prop(class_, key, uselist, callable_, typecallable, useobject, **kwa
         return ScalarAttributeImpl(class_, key, callable_,
                                            **kwargs)
 
-def manage(obj):
+def manage(instance):
     """initialize an InstanceState on the given instance."""
     
-    if not hasattr(obj, '_state'):
-        obj._state = InstanceState(obj)
-        
+    if not hasattr(instance, '_state'):
+        instance._state = InstanceState(instance)
+
 def new_instance(class_, state=None):
     """create a new instance of class_ without its __init__() method being called.
     
@@ -981,12 +987,18 @@ def new_instance(class_, state=None):
         s._state = InstanceState(s)
     return s
     
+def _init_class_state(class_):
+    if not '_class_state' in class_.__dict__:
+        class_._class_state = ClassState()
+    
 def register_class(class_, extra_init=None, on_exception=None):
     # do a sweep first, this also helps some attribute extensions
     # (like associationproxy) become aware of themselves at the 
     # class level
     for key in dir(class_):
         getattr(class_, key, None)
+
+    _init_class_state(class_)
     
     oldinit = None
     doinit = False
@@ -1032,15 +1044,15 @@ def unregister_class(class_):
         else:
             delattr(class_, '__init__')
     
-    for attr in noninherited_managed_attributes(class_):
-        if attr.impl.key in class_.__dict__ and isinstance(class_.__dict__[attr.impl.key], InstrumentedAttribute):
-            delattr(class_, attr.impl.key)
-    if '_sa_attrs' in class_.__dict__:
-        delattr(class_, '_sa_attrs')
+    if '_class_state' in class_.__dict__:
+        _class_state = class_.__dict__['_class_state']
+        for key, attr in _class_state.attrs.iteritems():
+            if key in class_.__dict__:
+                delattr(class_, attr.impl.key)
+        delattr(class_, '_class_state')
 
 def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, **kwargs):
-    if not '_sa_attrs' in class_.__dict__:
-        class_._sa_attrs = []
+    _init_class_state(class_)
         
     typecallable = kwargs.pop('typecallable', None)
     if isinstance(typecallable, InstrumentedAttribute):
@@ -1060,18 +1072,16 @@ def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_pr
                                        typecallable=typecallable, **kwargs), comparator=comparator)
     
     setattr(class_, key, inst)
-    class_._sa_attrs.append(inst)
+    class_._class_state.attrs[key] = inst
 
 def unregister_attribute(class_, key):
-    if key in class_.__dict__:
-        attr = getattr(class_, key)
-        if isinstance(attr, InstrumentedAttribute):
-            class_._sa_attrs.remove(attr)
-            delattr(class_, key)
+    class_state = class_._class_state
+    if key in class_state.attrs:
+        del class_._class_state.attrs[key]
+        delattr(class_, key)
 
 def init_collection(instance, key):
     """Initialize a collection attribute and return the collection adapter."""
-
     attr = getattr(instance.__class__, key).impl
     state = instance._state
     user_data = attr.initialize(state)
index 67087c5708b97e5f99a276d278bbb7d88f80d78b..c69881622bac8b1b10c93eea8af04070dc060cc2 100644 (file)
@@ -14,10 +14,9 @@ from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
 from sqlalchemy.orm import sync, attributes
 from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
 
-__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
+__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry']
 
-# a dictionary mapping classes to their primary mappers
-mapper_registry = weakref.WeakKeyDictionary()
+_mapper_registry = weakref.WeakKeyDictionary()
 
 # a list of MapperExtensions that will be installed in all mappers by default
 global_extensions = []
@@ -88,7 +87,6 @@ class Mapper(object):
 
         self.class_ = class_
         self.entity_name = entity_name
-        self.class_key = ClassKey(class_, entity_name)
         self.primary_key_argument = primary_key
         self.non_primary = non_primary
         self.order_by = order_by
@@ -206,7 +204,10 @@ class Mapper(object):
         self.__props_init = True
         if hasattr(self.class_, 'c'):
             del self.class_.c
-        attributes.unregister_class(self.class_)
+        if not self.non_primary and self.entity_name in self._class_state.mappers:
+            del self._class_state.mappers[self.entity_name]
+        if not self._class_state.mappers:
+            attributes.unregister_class(self.class_)
         
     def compile(self):
         """Compile this mapper into its final internal format.
@@ -220,7 +221,7 @@ class Mapper(object):
             if self.__props_init:
                 return self
             # initialize properties on all mappers
-            for mapper in mapper_registry.values():
+            for mapper in chain(*_mapper_registry.values()):
                 if not mapper.__props_init:
                     mapper.__initialize_properties()
 
@@ -718,11 +719,13 @@ class Mapper(object):
         """
 
         if self.non_primary:
+            self._class_state = self.class_._class_state
             return
 
-        if not self.non_primary and (self.class_key in mapper_registry):
+        if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers):
              raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'.  Use non_primary=True to create a non primary Mapper.  clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name))
 
+            
         def extra_init(class_, oldinit, instance, args, kwargs):
             self.compile()
             if 'init_instance' in self.extension.methods:
@@ -732,10 +735,15 @@ class Mapper(object):
             util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
 
         attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
+        
+        self._class_state = self.class_._class_state
+        if self._class_state not in _mapper_registry:
+            _mapper_registry[self._class_state] = []
 
         _COMPILE_MUTEX.acquire()
         try:
-            mapper_registry[self.class_key] = self
+            _mapper_registry[self._class_state].append(self)
+            self.class_._class_state.mappers[self.entity_name] = self
         finally:
             _COMPILE_MUTEX.release()
 
@@ -806,11 +814,11 @@ class Mapper(object):
     def _is_primary_mapper(self):
         """Return True if this mapper is the primary mapper for its class key (class + entity_name)."""
         # FIXME: cant we just look at "non_primary" flag ?
-        return mapper_registry.get(self.class_key, None) is self
+        return self._class_state.mappers[self.entity_name] is self
 
     def primary_mapper(self):
         """Return the primary mapper corresponding to this mapper's class key (class + entity_name)."""
-        return mapper_registry[self.class_key]
+        return self._class_state.mappers[self.entity_name]
 
     def is_assigned(self, instance):
         """Return True if this mapper handles the given instance.
@@ -1485,26 +1493,6 @@ class Mapper(object):
 Mapper.logger = logging.class_logger(Mapper)
 
 
-class ClassKey(object):
-    """Key a class and an entity name to a mapper, via the mapper_registry."""
-
-    __metaclass__ = util.ArgSingleton
-
-    def __init__(self, class_, entity_name):
-        self.class_ = class_
-        self.entity_name = entity_name
-        self._hash = hash((self.class_, self.entity_name))
-        
-    def __hash__(self):
-        return self._hash
-
-    def __eq__(self, other):
-        return self is other
-
-    def __repr__(self):
-        return "ClassKey(%s, %s)" % (repr(self.class_), repr(self.entity_name))
-
-    
 def has_identity(object):
     return hasattr(object, '_instance_key')
 
@@ -1533,7 +1521,7 @@ def object_mapper(object, entity_name=None, raiseerror=True):
     """
 
     try:
-        mapper = mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))]
+        mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)]
     except (KeyError, AttributeError):
         if raiseerror:
             raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name)))
@@ -1548,7 +1536,7 @@ def class_mapper(class_, entity_name=None, compile=True):
     """
 
     try:
-        mapper = mapper_registry[ClassKey(class_, entity_name)]
+        mapper = class_._class_state.mappers[entity_name]
     except (KeyError, AttributeError):
         raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name))
     if compile:
index 28ef39aba419dca53f09d072f67fc7651bd1fb77..bb025a3ab3b51e9fb98e574cd5eecee20be480e4 100644 (file)
@@ -1076,7 +1076,7 @@ class Session(object):
         not be loaded in the course of performing this test.
         """
 
-        for attr in attributes.managed_attributes(instance.__class__):
+        for attr in attributes._managed_attributes(instance.__class__):
             if not include_collections and hasattr(attr.impl, 'get_collection'):
                 continue
             if attr.get_history(instance).is_modified():
index aa0ab1ca96ad032b16e71cfc81215645e430ca24..457404b6f2ee8299045fb5923f758c8d9bbe18fc 100644 (file)
@@ -147,7 +147,7 @@ class UnitOfWork(object):
             if x not in self.deleted 
             and (
                 x._state.modified
-                or (getattr(x.__class__, '_sa_has_mutable_scalars', False) and x.state.is_modified())
+                or (x.__class__._class_state.has_mutable_scalars and x.state.is_modified())
             )
             ])
 
@@ -162,7 +162,7 @@ class UnitOfWork(object):
 
         dirty = [x for x in self.identity_map.all_states()
             if x.modified
-            or (getattr(x.class_, '_sa_has_mutable_scalars', False) and x.is_modified())
+            or (x.class_._class_state.has_mutable_scalars and x.is_modified())
         ]
         
         if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0:
index 5ef01b8829229bc560ce24ded222008fdcdff237..ce267189f24c83f0df1d449a664232926ebae55b 100644 (file)
@@ -43,6 +43,7 @@ class EntityTest(AssertMixin):
     def tearDownAll(self):
         metadata.drop_all()
     def tearDown(self):
+        ctx.current.clear()
         clear_mappers()
         for t in metadata.table_iterator(reverse=True):
             t.delete().execute()