]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added 'entity_name' keyword argument to mapper. a mapper is now associated with...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 21:43:22 +0000 (21:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 3 Apr 2006 21:43:22 +0000 (21:43 +0000)
the class object as well as the optional entity_name parameter, which is a string defaulting to None.
any number of primary mappers can be created for a class, qualified by the entity name.  instances of those classes
will issue all of their load and save operations through their entity_name-qualified mapper, and maintain separate identity from an otherwise equilvalent object.

lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/mapping/properties.py
test/alltests.py
test/entity.py [new file with mode: 0644]

index d1b0738de8745ee95d5be711b7d1e6f7dc9c269e..6424c9d7938c00f0e08074bff33fea44dd8958ee 100644 (file)
@@ -72,6 +72,11 @@ class SmartProperty(object):
         h.append_nohistory(value)
 
 class ManagedAttribute(object):
+    """base class for a "managed attribute", which is attached to individual instances
+    of a class mapped to the keyname of the property, inside of a dictionary which is
+    attached to the object via the propertyname "_managed_attributes".  Attribute access
+    which occurs through the SmartProperty property object ultimately calls upon 
+    ManagedAttribute objects associated with the instance via this dictionary."""
     def __init__(self, obj, key):
         self.__obj = weakref.ref(obj)
         self.key = key
index d6b1fb015e578c905f319b2476d301df1f4561f8..2db6e715a6fd7ef12dee28394c4edd35ed2dff79 100644 (file)
@@ -66,11 +66,12 @@ class Mapper(object):
         self.extension = ext
 
         self.class_ = class_
+        self.entity_name = entity_name
+        self.class_key = ClassKey(class_, entity_name)
         self.is_primary = is_primary
         self.order_by = order_by
         self._options = {}
         self.always_refresh = always_refresh
-        self.entity_name = entity_name
         self.version_id_col = version_id_col
         
         if not issubclass(class_, object):
@@ -208,12 +209,9 @@ class Mapper(object):
         for primary_key in self.pks_by_table[self.table]:
             self._get_clause.clauses.append(primary_key == sql.bindparam("pk_"+primary_key.key))
 
-        if not mapper_registry.has_key(self.class_) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()):
+        if not mapper_registry.has_key(self.class_key) or self.is_primary or (inherits is not None and inherits._is_primary_mapper()):
             objectstore.global_attributes.reset_class_managed(self.class_)
             self._init_class()
-            self.identitytable = self.primarytable
-        else:
-            self.identitytable = mapper_registry[self.class_].primarytable
                 
         if inherits is not None:
             for key, prop in inherits.props.iteritems():
@@ -251,41 +249,50 @@ class Mapper(object):
         prop.init(key, self)
         
     def __str__(self):
-        return "Mapper|" + self.class_.__name__ + "|" + self.primarytable.name
+        return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + self.primarytable.name
     
     def _is_primary_mapper(self):
-        return mapper_registry.get(self.class_, None) is self
+        """returns True if this mapper is the primary mapper for its class key (class + entity_name)"""
+        return mapper_registry.get(self.class_key, None) is self
 
     def _primary_mapper(self):
-        return mapper_registry[self.class_]
-        
+        """returns the primary mapper corresponding to this mapper's class key (class + entity_name)"""
+        return mapper_registry[self.class_key]
+
+    def is_assigned(self, instance):
+        """returns True if this mapper is the primary mapper for the given instance.  this is dependent
+        not only on class assignment but the optional "entity_name" parameter as well."""
+        return instance.__class__ is self.class_ and getattr(instance, '_entity_name', None) == self.entity_name
+
     def _init_class(self):
         """sets up our classes' overridden __init__ method, this mappers hash key as its
         '_mapper' property, and our columns as its 'c' property.  if the class already had a
         mapper, the old __init__ method is kept the same."""
-        if not self.class_.__dict__.has_key('_mapper'):
-            oldinit = self.class_.__init__
-            def init(self, *args, **kwargs):
-                # this gets the AttributeManager to do some pre-initialization,
-                # in order to save on KeyErrors later on
-                objectstore.global_attributes.init_attr(self)
-                
-                nohist = kwargs.pop('_mapper_nohistory', False)
-                session = kwargs.pop('_sa_session', objectstore.get_session())
-                if not nohist:
-                    # register new with the correct session, before the object's 
-                    # constructor is called, since further assignments within the
-                    # constructor would otherwise bind it to whatever get_session() is.
-                    session.register_new(self)
-                if oldinit is not None:
-                    oldinit(self, *args, **kwargs)
-            # override oldinit, insuring that its not already one of our
-            # own modified inits
-            if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'):
-                init._sa_mapper_init = True
-                self.class_.__init__ = init
-        mapper_registry[self.class_] = self
-        self.class_.c = self.c
+        oldinit = self.class_.__init__
+        def init(self, *args, **kwargs):
+            self._entity_name = kwargs.pop('_sa_entity_name', None)
+
+            # this gets the AttributeManager to do some pre-initialization,
+            # in order to save on KeyErrors later on
+            objectstore.global_attributes.init_attr(self)
+            
+            nohist = kwargs.pop('_mapper_nohistory', False)
+            session = kwargs.pop('_sa_session', objectstore.get_session())
+            if not nohist:
+                # register new with the correct session, before the object's 
+                # constructor is called, since further assignments within the
+                # constructor would otherwise bind it to whatever get_session() is.
+                session.register_new(self)
+            if oldinit is not None:
+                oldinit(self, *args, **kwargs)
+        # override oldinit, insuring that its not already one of our
+        # own modified inits
+        if oldinit is None or not hasattr(oldinit, '_sa_mapper_init'):
+            init._sa_mapper_init = True
+            self.class_.__init__ = init
+        mapper_registry[self.class_key] = self
+        if self.entity_name is None:
+            self.class_.c = self.c
         
     def set_property(self, key, prop):
         self.props[key] = prop
@@ -325,7 +332,7 @@ class Mapper(object):
         """returns an instance of the object based on the given identifier, or None
         if not found.  The *ident argument is a 
         list of primary key columns in the order of the table def's primary key columns."""
-        key = objectstore.get_id_key(ident, self.class_)
+        key = objectstore.get_id_key(ident, self.class_, self.entity_name)
         #print "key: " + repr(key) + " ident: " + repr(ident)
         return self._get(key, ident)
         
@@ -352,7 +359,7 @@ class Mapper(object):
         
     def identity_key(self, *primary_key):
         """returns the instance key for the given identity value.  this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key."""
-        return objectstore.get_id_key(tuple(primary_key), self.class_)
+        return objectstore.get_id_key(tuple(primary_key), self.class_, self.entity_name)
     
     def instance_key(self, instance):
         """returns the instance key for the given instance.  this is a global tracking object used by the objectstore, and is usually available off a mapped object as instance._instance_key."""
@@ -847,7 +854,7 @@ class Mapper(object):
         return statement
         
     def _identity_key(self, row):
-        return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table])
+        return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table], self.entity_name)
 
     def _instance(self, row, imap, result = None, populate_existing = False):
         """pulls an object instance from the given row and appends it to the given result
@@ -886,7 +893,7 @@ class Mapper(object):
             # plugin point
             instance = self.extension.create_instance(self, row, imap, self.class_)
             if instance is EXT_PASS:
-                instance = self.class_(_mapper_nohistory=True)
+                instance = self.class_(_mapper_nohistory=True, _sa_entity_name=self.entity_name)
             imap[identitykey] = instance
             isnew = True
         else:
@@ -1086,7 +1093,16 @@ class MapperExtension(object):
         if self.next is not None:
             self.next.before_delete(mapper, instance)
 
-        
+class ClassKey(object):
+    """keys a class and an entity name to a mapper, via the mapper_registry"""
+    def __init__(self, class_, entity_name):
+        self.class_ = class_
+        self.entity_name = entity_name
+    def __hash__(self):
+        return hash((self.class_, self.entity_name))
+    def __eq__(self, other):
+        return self.class_ is other.class_ and self.entity_name == other.entity_name
+            
 def hash_key(obj):
     if obj is None:
         return 'None'
@@ -1100,11 +1116,14 @@ def hash_key(obj):
 def object_mapper(object):
     """given an object, returns the primary Mapper associated with the object
     or the object's class."""
-    return class_mapper(object.__class__)
+    try:
+        return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', None))]
+    except KeyError:
+        raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None)))
 
-def class_mapper(class_):
-    """given a class, returns the primary Mapper associated with the class."""
+def class_mapper(class_, entity_name=None):
+    """given a ClassKey, returns the primary Mapper associated with the key."""
     try:
-        return mapper_registry[class_]
+        return mapper_registry[ClassKey(class_, entity_name)]
     except (KeyError, AttributeError):
-        raise InvalidRequestError("Class '%s' has no mapper associated with it" % class_.__name__)
+        raise InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name))
index 7827c1b7874ac1bd760200d009a4bd726a86d561..ee0470cde06ca4a5b76ef845e7d26c52da2d6a82 100644 (file)
@@ -51,7 +51,7 @@ class Session(object):
         if self.__pushed_count == 0:
             for n in self.nest_on:
                 n.pop_session()
-    def get_id_key(ident, class_):
+    def get_id_key(ident, class_, entity_name=None):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a tuple of the object's primary key values.
 
@@ -60,15 +60,12 @@ class Session(object):
 
         class_ - a reference to the object's class
 
-        table - a Table object where the object's primary fields are stored.
-
-        selectable - a Selectable object which represents all the object's column-based fields.
-        this Selectable may be synonymous with the table argument or can be a larger construct
-        containing that table. return value: a tuple object which is used as an identity key. """
-        return (class_, tuple(ident))
+        entity_name - optional string name to further qualify the class
+        """
+        return (class_, tuple(ident), entity_name)
     get_id_key = staticmethod(get_id_key)
 
-    def get_row_key(row, class_, primary_key):
+    def get_row_key(row, class_, primary_key, entity_name=None):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a result set row.
 
@@ -77,13 +74,12 @@ class Session(object):
 
         class_ - a reference to the object's class
 
-        table - a Table object where the object's primary fields are stored.
-
-        selectable - a Selectable object which represents all the object's column-based fields.
-        this Selectable may be synonymous with the table argument or can be a larger construct
-        containing that table. return value: a tuple object which is used as an identity key.
+        primary_key - a list of column objects that will target the primary key values
+        in the given row.
+        
+        entity_name - optional string name to further qualify the class
         """
-        return (class_, tuple([row[column] for column in primary_key]))
+        return (class_, tuple([row[column] for column in primary_key]), entity_name)
     get_row_key = staticmethod(get_row_key)
 
     class SessionTrans(object):
@@ -222,11 +218,11 @@ class Session(object):
             u.register_new(instance)
         return instance
 
-def get_id_key(ident, class_):
-    return Session.get_id_key(ident, class_)
+def get_id_key(ident, class_, entity_name=None):
+    return Session.get_id_key(ident, class_, entity_name)
 
-def get_row_key(row, class_, primary_key):
-    return Session.get_row_key(row, class_, primary_key)
+def get_row_key(row, class_, primary_key, entity_name=None):
+    return Session.get_row_key(row, class_, primary_key, entity_name)
 
 def begin():
     """begins a new UnitOfWork transaction.  the next commit will affect only
index 6d30c5ffe8d0d0dc05c55d97b3885cc96daf697a..14e4749099268d81f7c3c7c20deb0fa846371694 100644 (file)
@@ -590,6 +590,8 @@ class LazyLoader(PropertyLoader):
         objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, callable_=lambda i: self.setup_loader(i), extension=self.attributeext)
 
     def setup_loader(self, instance):
+        if not self.parent.is_assigned(instance):
+            return object_mapper(instance).props[self.key].setup_loader(instance)
         def lazyload():
             params = {}
             allparams = True
index b266ebcb138383b3fa7eadf3f0d78207dd0a1eb7..6b2d068cb595d87c741a4c37919ff762544bc8ed 100644 (file)
@@ -44,6 +44,7 @@ def suite():
         'cycles',
         
         # more select/persistence, backrefs
+        'entity',
         'manytomany',
         'onetoone',
         'inheritance',
diff --git a/test/entity.py b/test/entity.py
new file mode 100644 (file)
index 0000000..591cff7
--- /dev/null
@@ -0,0 +1,130 @@
+from testbase import PersistTest, AssertMixin
+import unittest
+from sqlalchemy import *
+import testbase
+
+from tables import *
+import tables
+
+class EntityTest(AssertMixin):
+    """tests mappers that are constructed based on "entity names", which allows the same class
+    to have multiple primary mappers """
+    def setUpAll(self):
+        global user1, user2, address1, address2
+        db = testbase.db
+        user1 = Table('user1', db, 
+            Column('user_id', Integer, Sequence('user1_id_seq'), primary_key=True),
+            Column('name', String(60), nullable=False)
+            ).create()
+        user2 = Table('user2', db, 
+            Column('user_id', Integer, Sequence('user2_id_seq'), primary_key=True),
+            Column('name', String(60), nullable=False)
+            ).create()
+        address1 = Table('address1', db,
+            Column('address_id', Integer, Sequence('address1_id_seq'), primary_key=True),
+            Column('user_id', Integer, ForeignKey(user1.c.user_id), nullable=False),
+            Column('email', String(100), nullable=False)
+            ).create()
+        address2 = Table('address2', db,
+            Column('address_id', Integer, Sequence('address2_id_seq'), primary_key=True),
+            Column('user_id', Integer, ForeignKey(user2.c.user_id), nullable=False),
+            Column('email', String(100), nullable=False)
+            ).create()
+    def tearDownAll(self):
+        address1.drop()
+        address2.drop()
+        user1.drop()
+        user2.drop()
+    def tearDown(self):
+        address1.delete().execute()
+        address2.delete().execute()
+        user1.delete().execute()
+        user2.delete().execute()
+        objectstore.clear()
+        clear_mappers()
+
+    def testbasic(self):
+        """tests a pair of one-to-many mapper structures, establishing that both
+        parent and child objects honor the "entity_name" attribute attached to the object
+        instances."""
+        class User(object):pass
+        class Address(object):pass
+            
+        a1mapper = mapper(Address, address1, entity_name='address1')
+        a2mapper = mapper(Address, address2, entity_name='address2')    
+        u1mapper = mapper(User, user1, entity_name='user1', properties ={
+            'addresses':relation(a1mapper)
+        })
+        u2mapper =mapper(User, user2, entity_name='user2', properties={
+            'addresses':relation(a2mapper)
+        })
+        
+        u1 = User(_sa_entity_name='user1')
+        u1.name = 'this is user 1'
+        a1 = Address(_sa_entity_name='address1')
+        a1.email='a1@foo.com'
+        u1.addresses.append(a1)
+        
+        u2 = User(_sa_entity_name='user2')
+        u2.name='this is user 2'
+        a2 = Address(_sa_entity_name='address2')
+        a2.email='a2@foo.com'
+        u2.addresses.append(a2)
+        
+        objectstore.commit()
+        assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
+        assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
+        assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
+        assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
+
+        objectstore.clear()
+        u1list = u1mapper.select()
+        u2list = u2mapper.select()
+        assert len(u1list) == len(u2list) == 1
+        assert u1list[0] is not u2list[0]
+        assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
+
+    def testpolymorphic(self):
+        """tests that entity_name can be used to have two kinds of relations on the same class."""
+        class User(object):pass
+        class Address1(object):pass
+        class Address2(object):pass
+            
+        a1mapper = mapper(Address1, address1)
+        a2mapper = mapper(Address2, address2)    
+        u1mapper = mapper(User, user1, entity_name='user1', properties ={
+            'addresses':relation(a1mapper)
+        })
+        u2mapper =mapper(User, user2, entity_name='user2', properties={
+            'addresses':relation(a2mapper)
+        })
+
+        u1 = User(_sa_entity_name='user1')
+        u1.name = 'this is user 1'
+        a1 = Address1()
+        a1.email='a1@foo.com'
+        u1.addresses.append(a1)
+
+        u2 = User(_sa_entity_name='user2')
+        u2.name='this is user 2'
+        a2 = Address2()
+        a2.email='a2@foo.com'
+        u2.addresses.append(a2)
+
+        objectstore.commit()
+        assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
+        assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
+        assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
+        assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
+
+        objectstore.clear()
+        u1list = u1mapper.select()
+        u2list = u2mapper.select()
+        assert len(u1list) == len(u2list) == 1
+        assert u1list[0] is not u2list[0]
+        assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
+        assert isinstance(u1list[0].addresses[0], Address1)
+        assert isinstance(u2list[0].addresses[0], Address2)
+        
+if __name__ == "__main__":    
+    testbase.main()