From abc376c7025b53c790392026e0deaf305299ab6f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 3 Apr 2006 21:43:22 +0000 Subject: [PATCH] added 'entity_name' keyword argument to mapper. a mapper is now associated with a class via the class object as well as the optional entity_name parameter, which is a string defaulting to None. any number of primary mappers can be created for a class, qualified by the entity name. instances of those classes will issue all of their load and save operations through their entity_name-qualified mapper, and maintain separate identity from an otherwise equilvalent object. --- lib/sqlalchemy/attributes.py | 5 + lib/sqlalchemy/mapping/mapper.py | 103 +++++++++++--------- lib/sqlalchemy/mapping/objectstore.py | 32 +++---- lib/sqlalchemy/mapping/properties.py | 2 + test/alltests.py | 1 + test/entity.py | 130 ++++++++++++++++++++++++++ 6 files changed, 213 insertions(+), 60 deletions(-) create mode 100644 test/entity.py diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index d1b0738de8..6424c9d793 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -72,6 +72,11 @@ class SmartProperty(object): h.append_nohistory(value) class ManagedAttribute(object): + """base class for a "managed attribute", which is attached to individual instances + of a class mapped to the keyname of the property, inside of a dictionary which is + attached to the object via the propertyname "_managed_attributes". Attribute access + which occurs through the SmartProperty property object ultimately calls upon + ManagedAttribute objects associated with the instance via this dictionary.""" def __init__(self, obj, key): self.__obj = weakref.ref(obj) self.key = key diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index d6b1fb015e..2db6e715a6 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -66,11 +66,12 @@ class Mapper(object): self.extension = ext self.class_ = class_ + self.entity_name = entity_name + self.class_key = ClassKey(class_, entity_name) self.is_primary = is_primary self.order_by = order_by self._options = {} self.always_refresh = always_refresh - self.entity_name = entity_name self.version_id_col = version_id_col if not issubclass(class_, object): @@ -208,12 +209,9 @@ class Mapper(object): for primary_key in self.pks_by_table[self.table]: self._get_clause.clauses.append(primary_key == sql.bindparam("pk_"+primary_key.key)) - if not mapper_registry.has_key(self.class_) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()): + if not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()): objectstore.global_attributes.reset_class_managed(self.class_) self._init_class() - self.identitytable = self.primarytable - else: - self.identitytable = mapper_registry[self.class_].primarytable if inherits is not None: for key, prop in inherits.props.iteritems(): @@ -251,41 +249,50 @@ class Mapper(object): prop.init(key, self) def __str__(self): - return "Mapper|" + self.class_.__name__ + "|" + self.primarytable.name + return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + self.primarytable.name def _is_primary_mapper(self): - return mapper_registry.get(self.class_, None) is self + """returns True if this mapper is the primary mapper for its class key (class + entity_name)""" + return mapper_registry.get(self.class_key, None) is self def _primary_mapper(self): - return mapper_registry[self.class_] - + """returns the primary mapper corresponding to this mapper's class key (class + entity_name)""" + return mapper_registry[self.class_key] + + def is_assigned(self, instance): + """returns True if this mapper is the primary mapper for the given instance. this is dependent + not only on class assignment but the optional "entity_name" parameter as well.""" + return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name + def _init_class(self): """sets up our classes' overridden __init__ method, this mappers hash key as its '_mapper' property, and our columns as its 'c' property. if the class already had a mapper, the old __init__ method is kept the same.""" - if not self.class_.__dict__.has_key('_mapper'): - oldinit = self.class_.__init__ - def init(self, *args, **kwargs): - # this gets the AttributeManager to do some pre-initialization, - # in order to save on KeyErrors later on - objectstore.global_attributes.init_attr(self) - - nohist = kwargs.pop('_mapper_nohistory', False) - session = kwargs.pop('_sa_session', objectstore.get_session()) - if not nohist: - # register new with the correct session, before the object's - # constructor is called, since further assignments within the - # constructor would otherwise bind it to whatever get_session() is. - session.register_new(self) - if oldinit is not None: - oldinit(self, *args, **kwargs) - # override oldinit, insuring that its not already one of our - # own modified inits - if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): - init._sa_mapper_init = True - self.class_.__init__ = init - mapper_registry[self.class_] = self - self.class_.c = self.c + oldinit = self.class_.__init__ + def init(self, *args, **kwargs): + self._entity_name = kwargs.pop('_sa_entity_name', None) + + # this gets the AttributeManager to do some pre-initialization, + # in order to save on KeyErrors later on + objectstore.global_attributes.init_attr(self) + + nohist = kwargs.pop('_mapper_nohistory', False) + session = kwargs.pop('_sa_session', objectstore.get_session()) + if not nohist: + # register new with the correct session, before the object's + # constructor is called, since further assignments within the + # constructor would otherwise bind it to whatever get_session() is. + session.register_new(self) + if oldinit is not None: + oldinit(self, *args, **kwargs) + # override oldinit, insuring that its not already one of our + # own modified inits + if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'): + init._sa_mapper_init = True + self.class_.__init__ = init + mapper_registry[self.class_key] = self + if self.entity_name is None: + self.class_.c = self.c def set_property(self, key, prop): self.props[key] = prop @@ -325,7 +332,7 @@ class Mapper(object): """returns an instance of the object based on the given identifier, or None if not found. The *ident argument is a list of primary key columns in the order of the table def's primary key columns.""" - key = objectstore.get_id_key(ident, self.class_) + key = objectstore.get_id_key(ident, self.class_, self.entity_name) #print "key: " + repr(key) + " ident: " + repr(ident) return self._get(key, ident) @@ -352,7 +359,7 @@ class Mapper(object): def identity_key(self, *primary_key): """returns the instance key for the given identity value. this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key.""" - return objectstore.get_id_key(tuple(primary_key), self.class_) + return objectstore.get_id_key(tuple(primary_key), self.class_, self.entity_name) def instance_key(self, instance): """returns the instance key for the given instance. this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key.""" @@ -847,7 +854,7 @@ class Mapper(object): return statement def _identity_key(self, row): - return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table]) + return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table], self.entity_name) def _instance(self, row, imap, result = None, populate_existing = False): """pulls an object instance from the given row and appends it to the given result @@ -886,7 +893,7 @@ class Mapper(object): # plugin point instance = self.extension.create_instance(self, row, imap, self.class_) if instance is EXT_PASS: - instance = self.class_(_mapper_nohistory=True) + instance = self.class_(_mapper_nohistory=True, _sa_entity_name=self.entity_name) imap[identitykey] = instance isnew = True else: @@ -1086,7 +1093,16 @@ class MapperExtension(object): if self.next is not None: self.next.before_delete(mapper, instance) - +class ClassKey(object): + """keys a class and an entity name to a mapper, via the mapper_registry""" + def __init__(self, class_, entity_name): + self.class_ = class_ + self.entity_name = entity_name + def __hash__(self): + return hash((self.class_, self.entity_name)) + def __eq__(self, other): + return self.class_ is other.class_ and self.entity_name == other.entity_name + def hash_key(obj): if obj is None: return 'None' @@ -1100,11 +1116,14 @@ def hash_key(obj): def object_mapper(object): """given an object, returns the primary Mapper associated with the object or the object's class.""" - return class_mapper(object.__class__) + try: + return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', None))] + except KeyError: + raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None))) -def class_mapper(class_): - """given a class, returns the primary Mapper associated with the class.""" +def class_mapper(class_, entity_name=None): + """given a ClassKey, returns the primary Mapper associated with the key.""" try: - return mapper_registry[class_] + return mapper_registry[ClassKey(class_, entity_name)] except (KeyError, AttributeError): - raise InvalidRequestError("Class '%s' has no mapper associated with it" % class_.__name__) + raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 7827c1b787..ee0470cde0 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -51,7 +51,7 @@ class Session(object): if self.__pushed_count == 0: for n in self.nest_on: n.pop_session() - def get_id_key(ident, class_): + def get_id_key(ident, class_, entity_name=None): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -60,15 +60,12 @@ class Session(object): class_ - a reference to the object's class - table - a Table object where the object's primary fields are stored. - - selectable - a Selectable object which represents all the object's column-based fields. - this Selectable may be synonymous with the table argument or can be a larger construct - containing that table. return value: a tuple object which is used as an identity key. """ - return (class_, tuple(ident)) + entity_name - optional string name to further qualify the class + """ + return (class_, tuple(ident), entity_name) get_id_key = staticmethod(get_id_key) - def get_row_key(row, class_, primary_key): + def get_row_key(row, class_, primary_key, entity_name=None): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a result set row. @@ -77,13 +74,12 @@ class Session(object): class_ - a reference to the object's class - table - a Table object where the object's primary fields are stored. - - selectable - a Selectable object which represents all the object's column-based fields. - this Selectable may be synonymous with the table argument or can be a larger construct - containing that table. return value: a tuple object which is used as an identity key. + primary_key - a list of column objects that will target the primary key values + in the given row. + + entity_name - optional string name to further qualify the class """ - return (class_, tuple([row[column] for column in primary_key])) + return (class_, tuple([row[column] for column in primary_key]), entity_name) get_row_key = staticmethod(get_row_key) class SessionTrans(object): @@ -222,11 +218,11 @@ class Session(object): u.register_new(instance) return instance -def get_id_key(ident, class_): - return Session.get_id_key(ident, class_) +def get_id_key(ident, class_, entity_name=None): + return Session.get_id_key(ident, class_, entity_name) -def get_row_key(row, class_, primary_key): - return Session.get_row_key(row, class_, primary_key) +def get_row_key(row, class_, primary_key, entity_name=None): + return Session.get_row_key(row, class_, primary_key, entity_name) def begin(): """begins a new UnitOfWork transaction. the next commit will affect only diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 6d30c5ffe8..14e4749099 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -590,6 +590,8 @@ class LazyLoader(PropertyLoader): objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, callable_=lambda i: self.setup_loader(i), extension=self.attributeext) def setup_loader(self, instance): + if not self.parent.is_assigned(instance): + return object_mapper(instance).props[self.key].setup_loader(instance) def lazyload(): params = {} allparams = True diff --git a/test/alltests.py b/test/alltests.py index b266ebcb13..6b2d068cb5 100644 --- a/test/alltests.py +++ b/test/alltests.py @@ -44,6 +44,7 @@ def suite(): 'cycles', # more select/persistence, backrefs + 'entity', 'manytomany', 'onetoone', 'inheritance', diff --git a/test/entity.py b/test/entity.py new file mode 100644 index 0000000000..591cff7ea1 --- /dev/null +++ b/test/entity.py @@ -0,0 +1,130 @@ +from testbase import PersistTest, AssertMixin +import unittest +from sqlalchemy import * +import testbase + +from tables import * +import tables + +class EntityTest(AssertMixin): + """tests mappers that are constructed based on "entity names", which allows the same class + to have multiple primary mappers """ + def setUpAll(self): + global user1, user2, address1, address2 + db = testbase.db + user1 = Table('user1', db, + Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True), + Column('name', String(60), nullable=False) + ).create() + user2 = Table('user2', db, + Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True), + Column('name', String(60), nullable=False) + ).create() + address1 = Table('address1', db, + Column('address_id', Integer, Sequence('address1_id_seq'), primary_key=True), + Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False), + Column('email', String(100), nullable=False) + ).create() + address2 = Table('address2', db, + Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True), + Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False), + Column('email', String(100), nullable=False) + ).create() + def tearDownAll(self): + address1.drop() + address2.drop() + user1.drop() + user2.drop() + def tearDown(self): + address1.delete().execute() + address2.delete().execute() + user1.delete().execute() + user2.delete().execute() + objectstore.clear() + clear_mappers() + + def testbasic(self): + """tests a pair of one-to-many mapper structures, establishing that both + parent and child objects honor the "entity_name" attribute attached to the object + instances.""" + class User(object):pass + class Address(object):pass + + a1mapper = mapper(Address, address1, entity_name='address1') + a2mapper = mapper(Address, address2, entity_name='address2') + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }) + + u1 = User(_sa_entity_name='user1') + u1.name = 'this is user 1' + a1 = Address(_sa_entity_name='address1') + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User(_sa_entity_name='user2') + u2.name='this is user 2' + a2 = Address(_sa_entity_name='address2') + a2.email='a2@foo.com' + u2.addresses.append(a2) + + objectstore.commit() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + objectstore.clear() + u1list = u1mapper.select() + u2list = u2mapper.select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + + def testpolymorphic(self): + """tests that entity_name can be used to have two kinds of relations on the same class.""" + class User(object):pass + class Address1(object):pass + class Address2(object):pass + + a1mapper = mapper(Address1, address1) + a2mapper = mapper(Address2, address2) + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }) + + u1 = User(_sa_entity_name='user1') + u1.name = 'this is user 1' + a1 = Address1() + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User(_sa_entity_name='user2') + u2.name='this is user 2' + a2 = Address2() + a2.email='a2@foo.com' + u2.addresses.append(a2) + + objectstore.commit() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + objectstore.clear() + u1list = u1mapper.select() + u2list = u2mapper.select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + assert isinstance(u1list[0].addresses[0], Address1) + assert isinstance(u2list[0].addresses[0], Address2) + +if __name__ == "__main__": + testbase.main() -- 2.47.2