all mappers including non primaries, and is strictly used for the list of "to compile/dispose".
- all global references are now weakly referenced; if you del a mapped class and any dependent classes,
its mapper and all dependencies fall out of scope.
- attributes.py still had issues which were barely covered by tests. added way more tests (coverage.py still says 71%, doh)
fixed things, took out unnecessary commit to states. attribute history is also asserted for ordering.
This is equivalent to calling ``compile()`` on any individual mapper.
"""
- if not _mapper_registry:
- return
- _mapper_registry.values()[0][0].compile()
+ for m in list(_mapper_registry):
+ m.compile()
def clear_mappers():
"""Remove all mappers that have been created thus far.
"""
mapperlib._COMPILE_MUTEX.acquire()
try:
- for mapper in chain(*_mapper_registry.values()):
+ for mapper in _mapper_registry:
mapper.dispose()
- _mapper_registry.clear()
- from sqlalchemy.orm import dependency
- dependency.MapperStub.dispose(dependency.MapperStub)
finally:
mapperlib._COMPILE_MUTEX.release()
return _create_history(self, state, current)
def fire_append_event(self, state, value, initiator):
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
+ if self.key not in state.committed_state and self.key in state.dict:
+ state.committed_state[self.key] = self.copy(state.dict[self.key])
+
state.modified = True
if self.trackparent and value is not None:
ext.append(instance, value, initiator or self)
def fire_pre_remove_event(self, state, initiator):
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
+ if self.key not in state.committed_state and self.key in state.dict:
+ state.committed_state[self.key] = self.copy(state.dict[self.key])
def fire_remove_event(self, state, value, initiator):
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
+ if self.key not in state.committed_state and self.key in state.dict:
+ state.committed_state[self.key] = self.copy(state.dict[self.key])
+
state.modified = True
if self.trackparent and value is not None:
if initiator is self:
return
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
-
collection = self.get_collection(state, passive=passive)
if collection is PASSIVE_NORESULT:
state.get_pending(self.key).append(value)
if initiator is self:
return
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
-
collection = self.get_collection(state, passive=passive)
if collection is PASSIVE_NORESULT:
state.get_pending(self.key).remove(value)
if initiator is self:
return
- if self.key not in state.committed_state:
- if self.key in state.dict:
- state.committed_state[self.key] = self.copy(state.dict[self.key])
- else:
- state.committed_state[self.key] = NO_VALUE
-
# we need a CollectionAdapter to adapt the incoming value to an
# assignable iterable. pulling a new collection first so that
# an adaptation exception does not trigger a lazy load of the
new_values = list(new_collection.adapt_like_to_iterable(value))
old = self.get(state)
+ state.committed_state[self.key] = self.copy(old)
+
old_collection = self.get_collection(state, old)
idset = util.IdentitySet
"""
collection, user_data = self._build_collection(state)
- self._load_collection(state, value or [], emit_events=False,
- collection=collection)
- value = user_data
- if state.committed_state is not None:
- state.commit_attr(self, value)
- # remove per-instance callable, if any
+ if value:
+ for item in value:
+ collection.append_without_event(item)
+
state.callables.pop(self.key, None)
- state.dict[self.key] = value
- return value
+ state.dict[self.key] = user_data
+
+ if self.key in state.pending:
+ # pending items. commit loaded data, add/remove new data
+ state.committed_state[self.key] = list(value or [])
+ added = state.pending[self.key].added_items
+ removed = state.pending[self.key].deleted_items
+ for item in added:
+ collection.append_without_event(item)
+ for item in removed:
+ collection.remove_without_event(item)
+ del state.pending[self.key]
+ elif self.key in state.committed_state:
+ # no pending items. remove committed state if any.
+ # (this can occur with an expired attribute)
+ del state.committed_state[self.key]
+
+ return user_data
def _build_collection(self, state):
"""build a new, blank collection and return it wrapped in a CollectionAdapter."""
collection = collections.CollectionAdapter(self, state, user_data)
return collection, user_data
- def _load_collection(self, state, values, emit_events=True, collection=None):
- """given an empty CollectionAdapter, load the collection with current values.
-
- Loads the collection from lazy callables in all cases.
- """
+ def get_collection(self, state, user_data=None, passive=False):
+ """retrieve the CollectionAdapter associated with the given state.
- collection = collection or self.get_collection(state)
- if values is None:
- return
-
- appender = emit_events and collection.append_with_event or collection.append_without_event
+ Creates a new CollectionAdapter if one does not exist.
- if self.key in state.pending:
- # move 'pending' items into the newly loaded collection
- added = state.pending[self.key].added_items
- removed = state.pending[self.key].deleted_items
- for item in values:
- if item not in removed:
- appender(item)
- for item in added:
- appender(item)
- del state.pending[self.key]
- else:
- for item in values:
- appender(item)
-
- def get_collection(self, state, user_data=None, passive=False):
- """retrieve the CollectionAdapter associated with the given state."""
+ """
if user_data is None:
user_data = self.get(state, passive=passive)
"""
__metaclass__ = util.ArgSingleton
-
+
def __init__(self, parent, mapper, key):
self.mapper = mapper
self.base_mapper = self
return False
def get_property(self, key, resolve_synonyms=False, raiseerr=True):
- """return MapperProperty with the given key."""
+ """return a MapperProperty associated with the given key."""
+
+ self.compile()
+ return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr)
+
+ def _get_property(self, key, resolve_synonyms=False, raiseerr=True):
+ """private in-compilation version of get_property()."""
+
prop = self.__props.get(key, None)
if resolve_synonyms:
while isinstance(prop, SynonymProperty):
return prop
def iterate_properties(self):
+ self.compile()
return self.__props.itervalues()
iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.")
properties = property(properties)
+ compiled = property(lambda self:self.__props_init, doc="return True if this mapper is compiled")
+
def dispose(self):
    # disable any attribute-based compilation
self.__props_init = True
- if hasattr(self.class_, 'c'):
+ try:
del self.class_.c
+ except AttributeError:
+ pass
if not self.non_primary and self.entity_name in self._class_state.mappers:
del self._class_state.mappers[self.entity_name]
if not self._class_state.mappers:
# double-check inside mutex
if self.__props_init:
return self
+
# initialize properties on all mappers
- for mapper in chain(*_mapper_registry.values()):
+ for mapper in list(_mapper_registry):
if not mapper.__props_init:
mapper.__initialize_properties()
- # if we're not primary, compile us
- if self.non_primary:
- self.__initialize_properties()
-
return self
finally:
_COMPILE_MUTEX.release()
- def _check_compile(self):
- if self.non_primary and not self.__props_init:
- self.__initialize_properties()
- return self
-
def __initialize_properties(self):
"""Call the ``init()`` method on all ``MapperProperties``
attached to this mapper.
if self.non_primary:
self._class_state = self.class_._class_state
+ _mapper_registry[self] = True
return
if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers):
attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
self._class_state = self.class_._class_state
- if self._class_state not in _mapper_registry:
- _mapper_registry[self._class_state] = []
+ _mapper_registry[self] = True
- _COMPILE_MUTEX.acquire()
- try:
- _mapper_registry[self._class_state].append(self)
- self.class_._class_state.mappers[self.entity_name] = self
- finally:
- _COMPILE_MUTEX.release()
+ self.class_._class_state.mappers[self.entity_name] = self
for ext in util.to_list(self.extension, []):
ext.instrument_class(self, self.class_)
def do_init(self):
class_ = self.parent.class_
- aliased_property = self.parent.get_property(self.key, resolve_synonyms=True)
+ aliased_property = self.parent._get_property(self.key, resolve_synonyms=True)
self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__))
if self.instrument is None:
class SynonymProp(object):
def _determine_targets(self):
if isinstance(self.argument, type):
- self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)._check_compile()
+ self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)
elif isinstance(self.argument, mapper.Mapper):
- self.mapper = self.argument._check_compile()
+ self.mapper = self.argument
else:
raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
# ensure the "select_mapper", if different from the regular target mapper, is compiled.
- self.mapper.get_select_mapper()._check_compile()
+ self.mapper.get_select_mapper()
if not self.parent.concrete:
for inheriting in self.parent.iterate_to_root():
- if inheriting is not self.parent and inheriting.get_property(self.key, raiseerr=False):
+ if inheriting is not self.parent and inheriting._get_property(self.key, raiseerr=False):
warnings.warn(RuntimeWarning("Warning: relation '%s' on mapper '%s' supercedes the same relation on inherited mapper '%s'; this can cause dependency issues during flush" % (self.key, self.parent, inheriting)))
if self.association is not None:
if isinstance(self.association, type):
- self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)._check_compile()
+ self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)
self.target = self.mapper.mapped_table
self.select_mapper = self.mapper.get_select_mapper()
if self.backref is not None:
self.backref.compile(self)
- elif not mapper.class_mapper(self.parent.class_).get_property(self.key, raiseerr=False):
+ elif not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False):
raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
super(PropertyLoader, self).do_init()
self.prop = prop
mapper = prop.mapper.primary_mapper()
- if mapper.get_property(self.key, raiseerr=False) is None:
+ if mapper._get_property(self.key, raiseerr=False) is None:
pj = self.kwargs.pop('primaryjoin', None)
sj = self.kwargs.pop('secondaryjoin', None)
mapper._compile_property(self.key, relation);
- prop.reverse_property = mapper.get_property(self.key)
- mapper.get_property(self.key).reverse_property = prop
+ prop.reverse_property = mapper._get_property(self.key)
+ mapper._get_property(self.key).reverse_property = prop
else:
raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper))
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import itertools, sys, warnings, sets
+import itertools, sys, warnings, sets, weakref
import __builtin__
from sqlalchemy import exceptions, logging
yield elem
class ArgSingleton(type):
- instances = {}
+ instances = weakref.WeakValueDictionary()
def dispose(cls):
for key in list(ArgSingleton.instances):
x = Foo()
x._state.commit_all()
x.col2.append(bar4)
- (added, unchanged, deleted) = attributes.get_history(x._state, 'col2')
-
- self.assertEquals(set(unchanged), set([bar1, bar2, bar3]))
- self.assertEquals(added, [bar4])
+ self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], []))
def test_parenttrack(self):
class Foo(object):pass
self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], []))
f._state.commit(['someattr'])
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([there, hi]), set()))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], []))
f.someattr.remove(there)
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([hi]), set([there])))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there]))
f.someattr.append(old)
f.someattr.append(new)
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([new, old]), set([hi]), set([there])))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there]))
f._state.commit(['someattr'])
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old, hi]), set()))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], []))
f.someattr.pop(0)
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set(), set([new, old]), set([hi])))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi]))
# case 2. object with direct settings (similar to a load operation)
f = Foo()
self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], []))
f._state.commit(['someattr'])
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([old, new]), set([])))
+ self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], []))
f = Foo()
collection = attributes.init_collection(f, 'someattr')
b2 = Bar()
f1.bars.append(b2)
- self.assertEquals(tuple([set(x) for x in attributes.get_history(f1._state, 'bars')]), (set([b1, b2]), set([]), set([])))
+ self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], []))
self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], []))
+
+ def test_lazy_backref_collections(self):
+ class Foo(fixtures.Base):
+ pass
+ class Bar(fixtures.Base):
+ pass
+
+ lazy_load = []
+ def lazyload(instance):
+ def load():
+ return lazy_load
+ return load
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+ attributes.register_attribute(Foo, 'bars', uselist=True, extension=attributes.GenericBackrefExtension('foo'), trackparent=True, callable_=lazyload, useobject=True)
+ attributes.register_attribute(Bar, 'foo', uselist=False, extension=attributes.GenericBackrefExtension('bars'), trackparent=True, useobject=True)
+
+ bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)]
+ lazy_load = [bar1, bar2, bar3]
+
+ f = Foo()
+ bar4 = Bar()
+ bar4.foo = f
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []))
+
+ lazy_load = None
+ f = Foo()
+ bar4 = Bar()
+ bar4.foo = f
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
+
+ lazy_load = [bar1, bar2, bar3]
+ f._state.trigger = lazyload(f)
+ f._state.expire_attributes(['bars'])
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
+
+ def test_collections_via_lazyload(self):
+ class Foo(fixtures.Base):
+ pass
+ class Bar(fixtures.Base):
+ pass
+
+ lazy_load = []
+ def lazyload(instance):
+ def load():
+ return lazy_load
+ return load
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+ attributes.register_attribute(Foo, 'bars', uselist=True, callable_=lazyload, trackparent=True, useobject=True)
+
+ bar1, bar2, bar3, bar4 = [Bar(id=1), Bar(id=2), Bar(id=3), Bar(id=4)]
+ lazy_load = [bar1, bar2, bar3]
+
+ f = Foo()
+ f.bars = []
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3]))
+
+ f = Foo()
+ f.bars.append(bar4)
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) )
+
+ f = Foo()
+ f.bars.remove(bar2)
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+ f.bars.append(bar4)
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2]))
+
+ f = Foo()
+ del f.bars[1]
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+
+ lazy_load = None
+ f = Foo()
+ f.bars.append(bar2)
+ self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], []))
+
+ def test_scalar_via_lazyload(self):
+ class Foo(fixtures.Base):
+ pass
+ class Bar(fixtures.Base):
+ pass
+
+ lazy_load = None
+ def lazyload(instance):
+ def load():
+ return lazy_load
+ return load
+
+ attributes.register_class(Foo)
+ attributes.register_class(Bar)
+ attributes.register_attribute(Foo, 'bar', uselist=False, callable_=lazyload, trackparent=True, useobject=True)
+ bar1, bar2 = [Bar(id=1), Bar(id=2)]
+ lazy_load = bar1
+
+ f = Foo()
+ self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+ f = Foo()
+ f.bar = None
+ self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+
+ f = Foo()
+ f.bar = bar2
+ self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1]))
+ f.bar = bar1
+ self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
if __name__ == "__main__":
testbase.main()
mapper(Foo, addresses, inherits=User)
assert getattr(Foo().__class__, 'user_name').impl is not None
+ def test_compileon_getprops(self):
+ m =mapper(User, users)
+
+ assert not m.compiled
+ assert list(m.iterate_properties)
+ assert m.compiled
+ clear_mappers()
+
+ m= mapper(User, users)
+ assert not m.compiled
+ assert m.get_property('user_name')
+ assert m.compiled
+
def test_add_property(self):
assert_col = []
class User(object):
for a in alist:
sess.delete(a)
sess.flush()
- clear_mappers()
+
+ # dont need to clear_mappers()
+ del B
+ del A
+
+ metadata.create_all()
+ try:
+ go()
+ finally:
+ metadata.drop_all()
+
+ def test_with_manytomany(self):
+ metadata = MetaData(testbase.db)
+
+ table1 = Table("mytable", metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30))
+ )
+
+ table2 = Table("mytable2", metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30)),
+ )
+
+ table3 = Table('t1tot2', metadata,
+ Column('t1', Integer, ForeignKey('mytable.col1')),
+ Column('t2', Integer, ForeignKey('mytable2.col1')),
+ )
+
+ @profile_memory
+ def go():
+ class A(Base):
+ pass
+ class B(Base):
+ pass
+
+ mapper(A, table1, properties={
+ 'bs':relation(B, secondary=table3, backref='as')
+ })
+ mapper(B, table2)
+
+ sess = create_session()
+ a1 = A(col2='a1')
+ a2 = A(col2='a2')
+ b1 = B(col2='b1')
+ b2 = B(col2='b2')
+ a1.bs.append(b1)
+ a2.bs.append(b2)
+ for x in [a1,a2]:
+ sess.save(x)
+ sess.flush()
+ sess.clear()
+
+ alist = sess.query(A).all()
+ self.assertEquals(
+ [
+ A(bs=[B(col2='b1')]), A(bs=[B(col2='b2')])
+ ],
+ alist)
+
+ for a in alist:
+ sess.delete(a)
+ sess.flush()
+
+ # dont need to clear_mappers()
+ del B
+ del A
metadata.create_all()
try:
sess.flush()
a1 = u1.addresses[0]
+ self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('jack',), ('jack',)])
+
assert sess.get(Address, a1.id) is u1.addresses[0]
u1.username = 'ed'
sess.flush()
assert u1.addresses[0].username == 'ed'
+ self.assertEquals(select([addresses.c.username]).execute().fetchall(), [('ed',), ('ed',)])
sess.clear()
self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
self.assert_sql_count(testbase.db, go, 1) # test passive_updates=True; update user
sess.clear()
assert User(username='jack', addresses=[Address(username='jack'), Address(username='jack')]) == sess.get(User, u1.id)
-
+ sess.clear()
+
u1 = sess.get(User, u1.id)
u1.addresses = []
u1.username = 'fred'
sess.flush()
sess.clear()
- assert sess.get(Address, a1.id).username is None
+ a1 = sess.get(Address, a1.id)
+ self.assertEquals(a1.username, None)
+
+ self.assertEquals(select([addresses.c.username]).execute().fetchall(), [(None,), (None,)])
+
u1 = sess.get(User, u1.id)
self.assertEquals(User(username='fred', fullname='jack'), u1)