c.employees.append(Engineer(name='wally', status='CGG', engineer_name='engineer2', primary_language='python'))
c.employees.append(Manager(name='jsmith', status='ABA', manager_name='manager2'))
session.save(c)
+
print session.new
session.flush()
#sys.exit()
and assocites this Mapper with its class via the mapper_registry."""
oldinit = self.class_.__init__
def init(self, *args, **kwargs):
- self._entity_name = kwargs.pop('_sa_entity_name', None)
# this gets the AttributeManager to do some pre-initialization,
# in order to save on KeyErrors later on
sessionlib.global_attributes.init_attr(self)
+ entity_name = kwargs.pop('_sa_entity_name', None)
if kwargs.has_key('_sa_session'):
session = kwargs.pop('_sa_session')
else:
# works for whatever mapper the class is associated with
- mapper = mapper_registry.get(ClassKey(self.__class__, self._entity_name))
+ mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
if mapper is not None:
session = mapper.extension.get_session()
if session is EXT_PASS:
session = None
else:
session = None
+ # if a session was found, either via _sa_session or via mapper extension,
+ # save() this instance to the session, and give it an associated entity_name.
+ # otherwise, this instance will not have a session or mapper association until it is
+ # save()d to some session.
if session is not None:
+ self._entity_name = entity_name
session._register_new(self)
+
if oldinit is not None:
oldinit(self, *args, **kwargs)
# override oldinit, insuring that its not already a Mapper-decorated init method
for prop in self.props.values():
prop.register_dependencies(uowcommit, *args, **kwargs)
- def cascade_iterator(self, type, object, recursive=None):
+ def cascade_iterator(self, type, object, callable_=None, recursive=None):
if recursive is None:
recursive=sets.Set()
- if object not in recursive:
- recursive.add(object)
- yield object
for prop in self.props.values():
for c in prop.cascade_iterator(type, object, recursive):
yield c
+ def cascade_callable(self, type, object, callable_, recursive=None):
+ if recursive is None:
+ recursive=sets.Set()
+ for prop in self.props.values():
+ prop.cascade_callable(type, object, callable_, recursive)
+
def _row_identity_key(self, row):
return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name)
raise NotImplementedError()
def cascade_iterator(self, type, object, recursive=None):
return []
+ def cascade_callable(self, type, object, callable_, recursive=None):
+ return []
def copy(self):
raise NotImplementedError()
def get_criterion(self, query, key, value):
return obj.hash_key()
else:
return repr(obj)
+
+def has_mapper(object):
+ """returns True if the given object has a mapper association"""
+ return hasattr(object, '_entity_name')
-def object_mapper(object, raiseerror=True, entity_name=None):
- """given an object, returns the primary Mapper associated with the object
- or the object's class."""
+def object_mapper(object, raiseerror=True):
+ """given an object, returns the primary Mapper associated with the object instance"""
try:
- return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))]
- except KeyError:
+ return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))]
+ except (KeyError, AttributeError):
if raiseerror:
raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None)))
else:
if not type in self.cascade:
return
childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
-
+
+ mapper = self.mapper.primary_mapper()
for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
- if c is not None:
- if c not in recursive:
- recursive.add(c)
- yield c
- for c2 in self.mapper.primary_mapper().cascade_iterator(type, c, recursive):
- yield c2
+ if c is not None and c not in recursive:
+ recursive.add(c)
+ yield c
+ for c2 in mapper.cascade_iterator(type, c, recursive):
+ yield c2
+
+ def cascade_callable(self, type, object, callable_, recursive):
+ if not type in self.cascade:
+ return
+ childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
+ mapper = self.mapper.primary_mapper()
+ for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
+ if c is not None and c not in recursive:
+ recursive.add(c)
+ callable_(c, mapper.entity_name)
+ mapper.cascade_callable(type, c, callable_, recursive)
def copy(self):
x = self.__class__.__new__(self.__class__)
raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (key, parent.class_.__name__, parent.class_.__name__))
self.do_init_subclass(key, parent)
+
+ def _register_attribute(self, class_, callable_=None):
+ sessionlib.global_attributes.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, callable_=callable_)
+
+ def _create_history(self, instance, callable_=None):
+ return sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_)
def _set_class_attribute(self, class_, key):
"""sets attribute behavior on our target class."""
- sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True)
+ self._register_attribute(class_)
def _get_direction(self):
"""determines our 'direction', i.e. do we represent one to many, many to many, etc."""
if self.is_primary():
return
#print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key
- sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+ self._create_history(instance)
def register_dependencies(self, uowcommit):
self._dependency_processor.register_dependencies(uowcommit)
def _set_class_attribute(self, class_, key):
# establish a class-level lazy loader on our class
#print "SETCLASSATTR LAZY", repr(class_), key
- sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True)
+ self._register_attribute(class_, callable_=lambda i: self.setup_loader(i))
def setup_loader(self, instance):
+ # make sure our parent mapper is the one thats assigned to this instance, else call that one
if not self.localparent.is_assigned(instance):
- return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
+ # if no mapper association with this instance (i.e. not in a session, not loaded by a mapper),
+ # then we cant set up a lazy loader
+ if not mapper.has_mapper(instance):
+ return None
+ else:
+ return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
def lazyload():
params = {}
allparams = True
#print "EXEC NON-PRIAMRY", repr(self.mapper.class_), self.key
# we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
# which will override the class-level behavior
- sessionlib.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance), cascade=self.cascade, trackparent=True)
+ self._create_history(instance, callable_=self.setup_loader(instance))
else:
#print "EXEC PRIMARY", repr(self.mapper.class_), self.key
# we are the primary manager for this attribute on this class - reset its per-instance attribute state,
if isnew:
# new row loaded from the database. initialize a blank container on the instance.
# this will override any per-class lazyloading type of stuff.
- h = sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+ h = self._create_history(instance)
if not self.uselist:
if isnew:
The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
instance.
"""
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- if c is object:
- self._save_impl(c, entity_name=entity_name)
- else:
- # TODO: this is running the cascade rules twice
- self.save_or_update(c, entity_name=entity_name)
+ self._save_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def update(self, object, entity_name=None):
"""Brings the given detached (saved) instance into this Session.
Session), an exception is thrown.
This operation cascades the "save_or_update" method to associated instances if the relation is mapped
with cascade="save-update"."""
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- if c is object:
- self._update_impl(c, entity_name=entity_name)
- else:
- self.save_or_update(c, entity_name=entity_name)
+ self._update_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
def save_or_update(self, object, entity_name=None):
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
- key = getattr(object, '_instance_key', None)
- if key is None:
- self._save_impl(c, entity_name=entity_name)
- else:
- self._update_impl(c, entity_name=entity_name)
-
+ self._save_or_update_impl(object, entity_name=entity_name)
+ object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+
+ def _save_or_update_impl(self, object, entity_name=None):
+ key = getattr(object, '_instance_key', None)
+ if key is None:
+ self._save_impl(object, entity_name=entity_name)
+ else:
+ self._update_impl(object, entity_name=entity_name)
+
def delete(self, object, entity_name=None):
- for c in object_mapper(object, entity_name=entity_name).cascade_iterator('delete', object):
+ #self.uow.register_deleted(object)
+ for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)):
self.uow.register_deleted(c)
def merge(self, object, entity_name=None):
instance = None
- for obj in object_mapper(object, entity_name=entity_name).cascade_iterator('merge', object):
+ for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)):
key = getattr(obj, '_instance_key', None)
if key is None:
- mapper = object_mapper(object, entity_name=entity_name)
+ mapper = object_mapper(object)
ident = mapper.identity(object)
for k in ident:
if k is None:
if not self.uow.has_key(object._instance_key):
raise exceptions.InvalidRequestError("Instance '%s' is already persistent in a different Session" % repr(object))
else:
- entity_name = kwargs.get('entity_name', None)
- if entity_name is not None:
- m = class_mapper(object.__class__, entity_name=entity_name)
- m._assign_entity_name(object)
+ m = class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None))
+ m._assign_entity_name(object)
self._register_new(object)
def _update_impl(self, object, **kwargs):
def get_row_key(row, class_, primary_key, entity_name=None):
return Session.get_row_key(row, class_, primary_key, entity_name)
-def object_mapper(obj, **kwargs):
- return sqlalchemy.orm.object_mapper(obj, **kwargs)
+def object_mapper(obj):
+ return sqlalchemy.orm.object_mapper(obj)
def class_mapper(class_, **kwargs):
return sqlalchemy.orm.class_mapper(class_, **kwargs)
if self.cascade is not None:
if not isdelete:
if self.cascade.save_update:
- sess.save_or_update(item)
+ # cascade the save_update operation onto the child object,
+ # relative to the mapper handling the parent object
+ # TODO: easier way to do this ?
+ mapper = object_mapper(obj)
+ prop = mapper.props[self.key]
+ ename = prop.mapper.entity_name
+ sess.save_or_update(item, entity_name=ename)
def append(self, item, _mapper_nohistory = False):
if _mapper_nohistory:
self.append_nohistory(item)
sess._register_changed(obj)
if newvalue is not None and self.cascade is not None:
if self.cascade.save_update:
- sess.save_or_update(newvalue)
+ # cascade the save_update operation onto the child object,
+ # relative to the mapper handling the parent object
+ # TODO: easier way to do this ?
+ mapper = object_mapper(obj)
+ prop = mapper.props[self.key]
+ ename = prop.mapper.entity_name
+ sess.save_or_update(newvalue, entity_name=ename)
class UOWAttributeManager(attributes.AttributeManager):
"""overrides AttributeManager to provide unit-of-work "dirty" hooks when scalar attribues are modified, plus factory methods for UOWProperrty/UOWListElement."""
def __init__(self):
attributes.AttributeManager.__init__(self)
-
+
def create_prop(self, class_, key, uselist, callable_, **kwargs):
return UOWProperty(class_, self, key, uselist, callable_, **kwargs)
assert u1list[0] is not u2list[0]
assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
+ def testcascade(self):
+ """same as testbasic but relies on session cascading"""
+ class User(object):pass
+ class Address(object):pass
+
+ a1mapper = mapper(Address, address1, entity_name='address1')
+ a2mapper = mapper(Address, address2, entity_name='address2')
+ u1mapper = mapper(User, user1, entity_name='user1', properties ={
+ 'addresses':relation(a1mapper)
+ })
+ u2mapper =mapper(User, user2, entity_name='user2', properties={
+ 'addresses':relation(a2mapper)
+ })
+
+ sess = create_session()
+ u1 = User()
+ u1.name = 'this is user 1'
+ sess.save(u1, entity_name='user1')
+ a1 = Address()
+ a1.email='a1@foo.com'
+ u1.addresses.append(a1)
+
+ u2 = User()
+ u2.name='this is user 2'
+ a2 = Address()
+ a2.email='a2@foo.com'
+ u2.addresses.append(a2)
+ sess.save(u2, entity_name='user2')
+
+ sess.flush()
+ assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
+ assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
+ assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
+ assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
+
+ sess.clear()
+ u1list = sess.query(User, entity_name='user1').select()
+ u2list = sess.query(User, entity_name='user2').select()
+ assert len(u1list) == len(u2list) == 1
+ assert u1list[0] is not u2list[0]
+ assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
+
def testpolymorphic(self):
"""tests that entity_name can be used to have two kinds of relations on the same class."""
class User(object):pass
class Location(BaseObject):
def __repr__(self):
- return "%s(%s, %s)" % (self.__class__.__name__, repr(self.issue_id), repr(str(self._name.name)))
+ return "%s(%s, %s)" % (self.__class__.__name__, str(self.issue_id), repr(str(self._name.name)))
def _get_name(self):
return self._name
ctx.current.flush([a])
ctx.current.delete(a)
+ print ctx.current.deleted
ctx.current.flush([a])
+# ctx.current.flush()
assert b_table.count().scalar() == 0