From: Mike Bayer Date: Sun, 2 Dec 2007 00:31:26 +0000 (+0000) Subject: - moved class-level attributes placed by the attributes package into a _class_state X-Git-Tag: rel_0_4_2~116 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=0ec4e7d6b35685ba4b5d9e2053c765984b4a9189;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - moved class-level attributes placed by the attributes package into a _class_state variable attached to the class. - mappers track themselves primarily using the "mappers" collection on _class_state. ClassKey is gone and mapper lookup uses regular dict keyed to entity_name; removes a fair degree of WeakKeyDictionary overhead as well as ClassKey overhead. - mapper_registry renamed to _mapper_registry; is only consulted by the compile_mappers(), mapper.compile() and clear_mappers() functions/methods. --- diff --git a/CHANGES b/CHANGES index 6e275254f3..a785ac1a97 100644 --- a/CHANGES +++ b/CHANGES @@ -37,7 +37,8 @@ CHANGES - several ORM attributes have been removed or made private: mapper.get_attr_by_column(), mapper.set_attr_by_column(), mapper.pks_by_table, mapper.cascade_callable(), - MapperProperty.cascade_callable(), mapper.canload() + MapperProperty.cascade_callable(), mapper.canload(), + mapper._mapper_registry, attributes.AttributeManager - fixed endless loop issue when using lazy="dynamic" on both sides of a bi-directional relationship [ticket:872] diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 9e42b12148..7f5672371b 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -11,7 +11,7 @@ constructors. """ from sqlalchemy import util as sautil -from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, mapper_registry +from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator from sqlalchemy.orm.properties import SynonymProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef from sqlalchemy.orm import mapper as mapperlib @@ -21,7 +21,7 @@ from sqlalchemy.orm.util import polymorphic_union, create_row_adapter from sqlalchemy.orm.session import Session as _Session from sqlalchemy.orm.session import object_session, sessionmaker from sqlalchemy.orm.scoping import ScopedSession - +from itertools import chain __all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload', 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', @@ -567,9 +567,9 @@ def compile_mappers(): This is equivalent to calling ``compile()`` on any individual mapper. """ - if not mapper_registry: + if not _mapper_registry: return - mapper_registry.values()[0].compile() + _mapper_registry.values()[0][0].compile() def clear_mappers(): """Remove all mappers that have been created thus far. @@ -579,10 +579,9 @@ def clear_mappers(): """ mapperlib._COMPILE_MUTEX.acquire() try: - for mapper in mapper_registry.values(): + for mapper in chain(*_mapper_registry.values()): mapper.dispose() - mapper_registry.clear() - mapperlib.ClassKey.dispose(mapperlib.ClassKey) + _mapper_registry.clear() from sqlalchemy.orm import dependency dependency.MapperStub.dispose(dependency.MapperStub) finally: diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 7b4d286e8d..5e3747e002 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import operator, weakref, threading +import weakref, threading, operator from itertools import chain import UserDict from sqlalchemy import util @@ -12,13 +12,15 @@ from sqlalchemy.orm import interfaces, collections from sqlalchemy.orm.util import identity_equal from sqlalchemy import exceptions - PASSIVE_NORESULT = object() ATTR_WAS_SET = object() NO_VALUE = object() class InstrumentedAttribute(interfaces.PropComparator): - """public-facing instrumented attribute.""" + """public-facing instrumented attribute, placed in the + class dictionary. + + """ def __init__(self, impl, comparator=None): """Construct an InstrumentedAttribute. @@ -29,19 +31,19 @@ class InstrumentedAttribute(interfaces.PropComparator): self.impl = impl self.comparator = comparator - def __set__(self, obj, value): - self.impl.set(obj._state, value, None) + def __set__(self, instance, value): + self.impl.set(instance._state, value, None) - def __delete__(self, obj): - self.impl.delete(obj._state) + def __delete__(self, instance): + self.impl.delete(instance._state) - def __get__(self, obj, owner): - if obj is None: + def __get__(self, instance, owner): + if instance is None: return self - return self.impl.get(obj._state) + return self.impl.get(instance._state) - def get_history(self, obj, **kwargs): - return self.impl.get_history(obj._state, **kwargs) + def get_history(self, instance, **kwargs): + return self.impl.get_history(instance._state, **kwargs) def clause_element(self): return self.comparator.clause_element() @@ -64,6 +66,10 @@ class InstrumentedAttribute(interfaces.PropComparator): property = property(_property, doc="the MapperProperty object associated with this attribute") class ProxiedAttribute(InstrumentedAttribute): + """a 'proxy' attribute which adds InstrumentedAttribute + class-level behavior to any user-defined class property. + """ + class ProxyImpl(object): def __init__(self, key): self.key = key @@ -76,17 +82,15 @@ class ProxiedAttribute(InstrumentedAttribute): self.comparator = comparator self.key = key self.impl = ProxiedAttribute.ProxyImpl(key) - def __get__(self, obj, owner): - if obj is None: - self.user_prop.__get__(obj, owner) + def __get__(self, instance, owner): + if instance is None: + self.user_prop.__get__(instance, owner) return self - return self.user_prop.__get__(obj, owner) - def __set__(self, obj, value): - return self.user_prop.__set__(obj, value) - def __delete__(self, obj): - return self.user_prop.__delete__(obj) - - + return self.user_prop.__get__(instance, owner) + def __set__(self, instance, value): + return self.user_prop.__set__(instance, value) + def __delete__(self, instance): + return self.user_prop.__delete__(instance) class AttributeImpl(object): """internal implementation for instrumented attributes.""" @@ -131,7 +135,7 @@ class AttributeImpl(object): self.trackparent = trackparent self.mutable_scalars = mutable_scalars if mutable_scalars: - class_._sa_has_mutable_scalars = True + class_._class_state.has_mutable_scalars = True self.copy = None if compare_function is None: self.is_equal = operator.eq @@ -276,17 +280,17 @@ class AttributeImpl(object): state.modified = True if self.trackparent and value is not None: self.sethasparent(value._state, True) - obj = state.obj() + instance = state.obj() for ext in self.extensions: - ext.append(obj, value, initiator or self) + ext.append(instance, value, initiator or self) def fire_remove_event(self, state, value, initiator): state.modified = True if self.trackparent and value is not None: self.sethasparent(value._state, False) - obj = state.obj() + instance = state.obj() for ext in self.extensions: - ext.remove(obj, value, initiator or self) + ext.remove(instance, value, initiator or self) def fire_replace_event(self, state, value, previous, initiator): state.modified = True @@ -295,9 +299,9 @@ class AttributeImpl(object): self.sethasparent(value._state, True) if previous is not None: self.sethasparent(previous._state, False) - obj = state.obj() + instance = state.obj() for ext in self.extensions: - ext.set(obj, value, previous, initiator or self) + ext.set(instance, value, previous, initiator or self) class ScalarAttributeImpl(AttributeImpl): """represents a scalar value-holding InstrumentedAttribute.""" @@ -331,7 +335,7 @@ class ScalarAttributeImpl(AttributeImpl): return False def set(self, state, value, initiator): - """Set a value on the given object. + """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the ``set()` operation and is used to control the depth of a circular @@ -367,7 +371,7 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): self.fire_remove_event(state, old, self) def set(self, state, value, initiator): - """Set a value on the given object. + """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the ``set()` operation and is used to control the depth of a circular @@ -542,7 +546,7 @@ class GenericBackrefExtension(interfaces.AttributeExtension): def __init__(self, key): self.key = key - def set(self, obj, child, oldchild, initiator): + def set(self, instance, child, oldchild, initiator): if oldchild is child: return if oldchild is not None: @@ -550,18 +554,25 @@ class GenericBackrefExtension(interfaces.AttributeExtension): # present when updating via a backref. impl = getattr(oldchild.__class__, self.key).impl try: - impl.remove(oldchild._state, obj, initiator) + impl.remove(oldchild._state, instance, initiator) except (ValueError, KeyError, IndexError): pass if child is not None: - getattr(child.__class__, self.key).impl.append(child._state, obj, initiator) + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator) - def append(self, obj, child, initiator): - getattr(child.__class__, self.key).impl.append(child._state, obj, initiator) + def append(self, instance, child, initiator): + getattr(child.__class__, self.key).impl.append(child._state, instance, initiator) - def remove(self, obj, child, initiator): - getattr(child.__class__, self.key).impl.remove(child._state, obj, initiator) + def remove(self, instance, child, initiator): + getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator) +class ClassState(object): + """tracks state information at the class level.""" + def __init__(self): + self.mappers = {} + self.attrs = {} + self.has_mutable_scalars = False + class InstanceState(object): """tracks state information at the instance level.""" @@ -583,7 +594,7 @@ class InstanceState(object): instance_dict = self.instance_dict if instance_dict is None: return - + instance_dict = instance_dict() if instance_dict is None: return @@ -599,11 +610,15 @@ class InstanceState(object): id2 = self.instance_dict if id2 is None or id2() is None or self.obj() is not None: return - - self.__resurrect(instance_dict) + + try: + self.__resurrect(instance_dict) + except: + # catch GC exceptions + pass finally: instance_dict._mutex.release() - + def _check_resurrect(self, instance_dict): instance_dict._mutex.acquire() try: @@ -614,8 +629,8 @@ class InstanceState(object): def is_modified(self): if self.modified: return True - elif getattr(self.class_, '_sa_has_mutable_scalars', False): - for attr in managed_attributes(self.class_): + elif self.class_._class_state.has_mutable_scalars: + for attr in _managed_attributes(self.class_): if getattr(attr.impl, 'mutable_scalars', False) and attr.impl.check_mutable_modified(self): return True else: @@ -669,7 +684,7 @@ class InstanceState(object): if not hasattr(self, 'expired_attributes'): self.expired_attributes = util.Set() if attribute_names is None: - for attr in managed_attributes(self.class_): + for attr in _managed_attributes(self.class_): self.dict.pop(attr.impl.key, None) self.callables[attr.impl.key] = self.__fire_trigger self.expired_attributes.add(attr.impl.key) @@ -707,7 +722,7 @@ class InstanceState(object): self.committed_state = {} self.modified = False - for attr in managed_attributes(self.class_): + for attr in _managed_attributes(self.class_): attr.impl.commit_to_state(self) # remove strong ref self._strong_obj = None @@ -802,9 +817,9 @@ class WeakInstanceDict(UserDict.UserDict): def itervalues(self): for state in self.data.itervalues(): - obj = state.obj() - if obj is not None: - yield obj + instance = state.obj() + if instance is not None: + yield instance def values(self): L = [] @@ -841,8 +856,6 @@ class AttributeHistory(object): particular instance. """ - NO_VALUE = object() - def __init__(self, attr, state, current, passive=False): self.attr = attr @@ -905,26 +918,19 @@ class AttributeHistory(object): def deleted_items(self): return list(self._deleted_items) -def managed_attributes(class_): +def _managed_attributes(class_): """return all InstrumentedAttributes associated with the given class_ and its superclasses.""" - return chain(*[getattr(cl, '_sa_attrs', []) for cl in class_.__mro__[:-1]]) + return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')]) -def noninherited_managed_attributes(class_): - """return all InstrumentedAttributes associated with the given class_, but not its superclasses.""" - - return getattr(class_, '_sa_attrs', []) - -def is_modified(obj): - return obj._state.is_modified() - - -def get_history(obj, key, **kwargs): +def is_modified(instance): + return instance._state.is_modified() - return getattr(obj.__class__, key).impl.get_history(obj._state, **kwargs) +def get_history(instance, key, **kwargs): + return getattr(instance.__class__, key).impl.get_history(instance._state, **kwargs) -def get_as_list(obj, key, passive=False): - """Return an attribute of the given name from the given object. +def get_as_list(instance, key, passive=False): + """Return an attribute of the given name from the given instance. If the attribute is a scalar, return it as a single-item list, otherwise return a collection based attribute. @@ -934,8 +940,8 @@ def get_as_list(obj, key, passive=False): `passive` flag is False. """ - attr = getattr(obj.__class__, key).impl - state = obj._state + attr = getattr(instance.__class__, key).impl + state = instance._state x = attr.get(state, passive=passive) if x is PASSIVE_NORESULT: return [] @@ -946,8 +952,8 @@ def get_as_list(obj, key, passive=False): else: return [x] -def has_parent(class_, obj, key, optimistic=False): - return getattr(class_, key).impl.hasparent(obj._state, optimistic=optimistic) +def has_parent(class_, instance, key, optimistic=False): + return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic) def _create_prop(class_, key, uselist, callable_, typecallable, useobject, **kwargs): if kwargs.pop('dynamic', False): @@ -962,12 +968,12 @@ def _create_prop(class_, key, uselist, callable_, typecallable, useobject, **kwa return ScalarAttributeImpl(class_, key, callable_, **kwargs) -def manage(obj): +def manage(instance): """initialize an InstanceState on the given instance.""" - if not hasattr(obj, '_state'): - obj._state = InstanceState(obj) - + if not hasattr(instance, '_state'): + instance._state = InstanceState(instance) + def new_instance(class_, state=None): """create a new instance of class_ without its __init__() method being called. @@ -981,12 +987,18 @@ def new_instance(class_, state=None): s._state = InstanceState(s) return s +def _init_class_state(class_): + if not '_class_state' in class_.__dict__: + class_._class_state = ClassState() + def register_class(class_, extra_init=None, on_exception=None): # do a sweep first, this also helps some attribute extensions # (like associationproxy) become aware of themselves at the # class level for key in dir(class_): getattr(class_, key, None) + + _init_class_state(class_) oldinit = None doinit = False @@ -1032,15 +1044,15 @@ def unregister_class(class_): else: delattr(class_, '__init__') - for attr in noninherited_managed_attributes(class_): - if attr.impl.key in class_.__dict__ and isinstance(class_.__dict__[attr.impl.key], InstrumentedAttribute): - delattr(class_, attr.impl.key) - if '_sa_attrs' in class_.__dict__: - delattr(class_, '_sa_attrs') + if '_class_state' in class_.__dict__: + _class_state = class_.__dict__['_class_state'] + for key, attr in _class_state.attrs.iteritems(): + if key in class_.__dict__: + delattr(class_, attr.impl.key) + delattr(class_, '_class_state') def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, **kwargs): - if not '_sa_attrs' in class_.__dict__: - class_._sa_attrs = [] + _init_class_state(class_) typecallable = kwargs.pop('typecallable', None) if isinstance(typecallable, InstrumentedAttribute): @@ -1060,18 +1072,16 @@ def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_pr typecallable=typecallable, **kwargs), comparator=comparator) setattr(class_, key, inst) - class_._sa_attrs.append(inst) + class_._class_state.attrs[key] = inst def unregister_attribute(class_, key): - if key in class_.__dict__: - attr = getattr(class_, key) - if isinstance(attr, InstrumentedAttribute): - class_._sa_attrs.remove(attr) - delattr(class_, key) + class_state = class_._class_state + if key in class_state.attrs: + del class_._class_state.attrs[key] + delattr(class_, key) def init_collection(instance, key): """Initialize a collection attribute and return the collection adapter.""" - attr = getattr(instance.__class__, key).impl state = instance._state user_data = attr.initialize(state) diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 67087c5708..c69881622b 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -14,10 +14,9 @@ from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter from sqlalchemy.orm import sync, attributes from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator -__all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry'] +__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry'] -# a dictionary mapping classes to their primary mappers -mapper_registry = weakref.WeakKeyDictionary() +_mapper_registry = weakref.WeakKeyDictionary() # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] @@ -88,7 +87,6 @@ class Mapper(object): self.class_ = class_ self.entity_name = entity_name - self.class_key = ClassKey(class_, entity_name) self.primary_key_argument = primary_key self.non_primary = non_primary self.order_by = order_by @@ -206,7 +204,10 @@ class Mapper(object): self.__props_init = True if hasattr(self.class_, 'c'): del self.class_.c - attributes.unregister_class(self.class_) + if not self.non_primary and self.entity_name in self._class_state.mappers: + del self._class_state.mappers[self.entity_name] + if not self._class_state.mappers: + attributes.unregister_class(self.class_) def compile(self): """Compile this mapper into its final internal format. @@ -220,7 +221,7 @@ class Mapper(object): if self.__props_init: return self # initialize properties on all mappers - for mapper in mapper_registry.values(): + for mapper in chain(*_mapper_registry.values()): if not mapper.__props_init: mapper.__initialize_properties() @@ -718,11 +719,13 @@ class Mapper(object): """ if self.non_primary: + self._class_state = self.class_._class_state return - if not self.non_primary and (self.class_key in mapper_registry): + if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers): raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper. clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name)) + def extra_init(class_, oldinit, instance, args, kwargs): self.compile() if 'init_instance' in self.extension.methods: @@ -732,10 +735,15 @@ class Mapper(object): util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs) attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception) + + self._class_state = self.class_._class_state + if self._class_state not in _mapper_registry: + _mapper_registry[self._class_state] = [] _COMPILE_MUTEX.acquire() try: - mapper_registry[self.class_key] = self + _mapper_registry[self._class_state].append(self) + self.class_._class_state.mappers[self.entity_name] = self finally: _COMPILE_MUTEX.release() @@ -806,11 +814,11 @@ class Mapper(object): def _is_primary_mapper(self): """Return True if this mapper is the primary mapper for its class key (class + entity_name).""" # FIXME: cant we just look at "non_primary" flag ? - return mapper_registry.get(self.class_key, None) is self + return self._class_state.mappers[self.entity_name] is self def primary_mapper(self): """Return the primary mapper corresponding to this mapper's class key (class + entity_name).""" - return mapper_registry[self.class_key] + return self._class_state.mappers[self.entity_name] def is_assigned(self, instance): """Return True if this mapper handles the given instance. @@ -1485,26 +1493,6 @@ class Mapper(object): Mapper.logger = logging.class_logger(Mapper) -class ClassKey(object): - """Key a class and an entity name to a mapper, via the mapper_registry.""" - - __metaclass__ = util.ArgSingleton - - def __init__(self, class_, entity_name): - self.class_ = class_ - self.entity_name = entity_name - self._hash = hash((self.class_, self.entity_name)) - - def __hash__(self): - return self._hash - - def __eq__(self, other): - return self is other - - def __repr__(self): - return "ClassKey(%s, %s)" % (repr(self.class_), repr(self.entity_name)) - - def has_identity(object): return hasattr(object, '_instance_key') @@ -1533,7 +1521,7 @@ def object_mapper(object, entity_name=None, raiseerror=True): """ try: - mapper = mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))] + mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)] except (KeyError, AttributeError): if raiseerror: raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name))) @@ -1548,7 +1536,7 @@ def class_mapper(class_, entity_name=None, compile=True): """ try: - mapper = mapper_registry[ClassKey(class_, entity_name)] + mapper = class_._class_state.mappers[entity_name] except (KeyError, AttributeError): raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) if compile: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 28ef39aba4..bb025a3ab3 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1076,7 +1076,7 @@ class Session(object): not be loaded in the course of performing this test. """ - for attr in attributes.managed_attributes(instance.__class__): + for attr in attributes._managed_attributes(instance.__class__): if not include_collections and hasattr(attr.impl, 'get_collection'): continue if attr.get_history(instance).is_modified(): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index aa0ab1ca96..457404b6f2 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -147,7 +147,7 @@ class UnitOfWork(object): if x not in self.deleted and ( x._state.modified - or (getattr(x.__class__, '_sa_has_mutable_scalars', False) and x.state.is_modified()) + or (x.__class__._class_state.has_mutable_scalars and x.state.is_modified()) ) ]) @@ -162,7 +162,7 @@ class UnitOfWork(object): dirty = [x for x in self.identity_map.all_states() if x.modified - or (getattr(x.class_, '_sa_has_mutable_scalars', False) and x.is_modified()) + or (x.class_._class_state.has_mutable_scalars and x.is_modified()) ] if len(dirty) == 0 and len(self.deleted) == 0 and len(self.new) == 0: diff --git a/test/orm/entity.py b/test/orm/entity.py index 5ef01b8829..ce267189f2 100644 --- a/test/orm/entity.py +++ b/test/orm/entity.py @@ -43,6 +43,7 @@ class EntityTest(AssertMixin): def tearDownAll(self): metadata.drop_all() def tearDown(self): + ctx.current.clear() clear_mappers() for t in metadata.table_iterator(reverse=True): t.delete().execute()