]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- moved test/orm/fixtures.py to testlib
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2007 22:15:15 +0000 (22:15 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Aug 2007 22:15:15 +0000 (22:15 +0000)
- flattened mapper calls in _instance() to operate directly
through a default MapperExtension
- more tests for ScopedSession, fixed [ticket:746]
- threadlocal engine propagates **kwargs through begin()

15 files changed:
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
test/orm/assorted_eager.py
test/orm/dynamic.py
test/orm/eager_relations.py
test/orm/inheritance/selects.py
test/orm/lazy_relations.py
test/orm/query.py
test/orm/selectable.py
test/orm/session.py
test/orm/unitofwork.py
test/testlib/fixtures.py [moved from test/orm/fixtures.py with 95% similarity]

index b6ba54ea587ed945b3144d0b7d8c382c71954690..4b251de13dfe4f3102eec777c0af44dbb9d6447f 100644 (file)
@@ -31,10 +31,10 @@ class TLSession(object):
     def in_transaction(self):
         return self.__tcount > 0
 
-    def begin(self):
+    def begin(self, **kwargs):
         if self.__tcount == 0:
             self.__transaction = self.get_connection()
-            self.__trans = self.__transaction._begin()
+            self.__trans = self.__transaction._begin(**kwargs)
         self.__tcount += 1
         return self.__trans
 
@@ -75,8 +75,8 @@ class TLConnection(base.Connection):
     def in_transaction(self):
         return self.session.in_transaction()
 
-    def begin(self):
-        return self.session.begin()
+    def begin(self, **kwargs):
+        return self.session.begin(**kwargs)
 
     def close(self):
         if self.__opencount == 1:
@@ -143,8 +143,8 @@ class TLEngine(base.Engine):
 
         return self.session.get_connection(**kwargs)
 
-    def begin(self):
-        return self.session.begin()
+    def begin(self, **kwargs):
+        return self.session.begin(**kwargs)
 
     def commit(self):
         self.session.commit()
index 74b184a7c254023ae2691b18daa3122d43aa2af8..c54eee4381bb7251506a96d3e0e47a48506bfcbc 100644 (file)
@@ -516,7 +516,7 @@ class ExtensionOption(MapperOption):
 
     def process_query(self, query):
         query._extension = query._extension.copy()
-        query._extension.append(self.ext)
+        query._extension.insert(self.ext)
 
 class SynonymProperty(MapperProperty):
     def __init__(self, name, proxy=False):
index 014d593d0ccc3788d222a984b4affb28e1813d65..db22015bfd32e5303be840b25a697dd1f8bc07cb 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy import sql_util as sqlutil
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier
 from sqlalchemy.orm import sync
-from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, MapperExtension, SynonymProperty
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, EXT_STOP, MapperExtension, SynonymProperty
 deferred_load = None
 
 __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
@@ -287,20 +287,22 @@ class Mapper(object):
         creates a linked list of those extensions.
         """
 
-        extlist = util.Set()
-        for ext_class in global_extensions:
-            if isinstance(ext_class, MapperExtension):
-                extlist.add(ext_class)
-            else:
-                extlist.add(ext_class())
-            # local MapperExtensions have already instrumented the class
-            extlist[-1].instrument_class(self, self.class_)
-            
+        extlist = util.OrderedSet()
+
         extension = self.extension
         if extension is not None:
             for ext_obj in util.to_list(extension):
+                # local MapperExtensions have already instrumented the class
                 extlist.add(ext_obj)
 
+        for ext in global_extensions:
+            if isinstance(ext_class, type):
+                ext = ext()
+            extlist.add(ext)
+            ext.instrument_class(self, self.class_)
+            
+        extlist.add(_DefaultExtension())
+        
         self.extension = ExtensionCarrier()
         for ext in extlist:
             self.extension.append(ext)
@@ -1387,9 +1389,7 @@ class Mapper(object):
         else:
             extension = self.extension
 
-        ret = extension.translate_row(self, context, row)
-        if ret is not EXT_CONTINUE:
-            row = ret
+        row = extension.translate_row(self, context, row)
 
         if not skip_polymorphic and self.polymorphic_on is not None:
             discriminator = row[self.polymorphic_on]
@@ -1407,7 +1407,7 @@ class Mapper(object):
 
         identitykey = self.identity_key_from_row(row)
         populate_existing = context.populate_existing or self.always_refresh
-        if context.session.has_key(identitykey):
+        if identitykey in context.session.identity_map:
             instance = context.session._get(identitykey)
             if self.__should_log_debug:
                 self.__log_debug("_instance(): using existing instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
@@ -1419,11 +1419,8 @@ class Mapper(object):
                 if identitykey not in context.identity_map:
                     context.identity_map[identitykey] = instance
                     isnew = True
-                if extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                    self.populate_instance(context, instance, row, instancekey=identitykey, isnew=isnew)
-            if extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                if result is not None:
-                    result.append(instance)
+                extension.populate_instance(self, context, row, instance, instancekey=identitykey, isnew=isnew)
+            extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew)
             return instance
         else:
             if self.__should_log_debug:
@@ -1447,10 +1444,7 @@ class Mapper(object):
 
             # plugin point
             instance = extension.create_instance(self, context, row, self.class_)
-            if instance is EXT_CONTINUE:
-                instance = self._create_instance(context.session)
-            else:
-                instance._entity_name = self.entity_name
+            instance._entity_name = self.entity_name
             if self.__should_log_debug:
                 self.__log_debug("_instance(): created new instance %s identity %s" % (mapperutil.instance_str(instance), str(identitykey)))
             context.identity_map[identitykey] = instance
@@ -1462,19 +1456,10 @@ class Mapper(object):
         # call further mapper properties on the row, to pull further
         # instances from the row and possibly populate this item.
         flags = {'instancekey':identitykey, 'isnew':isnew}
-        if extension.populate_instance(self, context, row, instance, **flags) is EXT_CONTINUE:
-            self.populate_instance(context, instance, row, **flags)
-        if extension.append_result(self, context, row, instance, result, **flags) is EXT_CONTINUE:
-            if result is not None:
-                result.append(instance)
+        extension.populate_instance(self, context, row, instance, **flags)
+        extension.append_result(self, context, row, instance, result, **flags)
         return instance
 
-    def _create_instance(self, session):
-        obj = self.class_.__new__(self.class_)
-        obj._entity_name = self.entity_name
-
-        return obj
-
     def _deferred_inheritance_condition(self, needs_tables):
         cond = self.inherit_condition
 
@@ -1575,7 +1560,21 @@ class Mapper(object):
 Mapper.logger = logging.class_logger(Mapper)
 
 
+class _DefaultExtension(MapperExtension):
+    def translate_row(self, mapper, context, row):
+        return row
 
+    def populate_instance(self, mapper, context, row, instance, instancekey, isnew):
+        mapper.populate_instance(context, instance, row, instancekey=instancekey, isnew=isnew)
+        return EXT_STOP
+        
+    def append_result(self, mapper, context, row, instance, result, instancekey, isnew):
+        if result is not None:
+            result.append(instance)
+        return EXT_STOP
+        
+    def create_instance(self, mapper, context, row, class_):
+        return class_.__new__(class_)
 
 class ClassKey(object):
     """Key a class and an entity name to a mapper, via the mapper_registry."""
index 5dad7412456a5ec4a1ca7c4fafe107b3a93da812..aebcfcdfe5892c91ca94ba2e3d69d26c7fc95510 100644 (file)
@@ -104,7 +104,7 @@ class _ScopedExt(MapperExtension):
             def __call__(s):
                 return self.context.registry().query(class_)
 
-        if not hasattr(class_, 'query')
+        if not 'query' in class_.__dict__
             class_.query = query()
         
     def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
index 4fa9a4067ad59d19f77419381185df124df0b813..6263a2e525a838ae63a368b63d7a0689ab2c3c34 100644 (file)
@@ -390,6 +390,7 @@ class Session(object):
             
         """
         self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
+        self.identity_map = self.uow.identity_map
 
         self.bind = bind
         self.__binds = {}
@@ -573,6 +574,7 @@ class Session(object):
             self._unattach(instance)
         echo = self.uow.echo
         self.uow = unitofwork.UnitOfWork(weak_identity_map=self.weak_identity_map)
+        self.identity_map = self.uow.identity_map
         self.uow.echo = echo
 
     def bind_mapper(self, mapper, bind, entity_name=None):
@@ -853,7 +855,7 @@ class Session(object):
         try:
             key = getattr(object, '_instance_key', None)
             if key is None:
-                merged = mapper._create_instance(self)
+                merged = mapper.class_.__new__(mapper.class_)
             else:
                 if key in self.identity_map:
                     merged = self.identity_map[key]
@@ -1033,11 +1035,6 @@ class Session(object):
     def _get(self, key):
         return self.identity_map[key]
 
-    def has_key(self, key):
-        """return True if the given identity key is present within this Session's identity map."""
-        
-        return key in self.identity_map
-
     dirty = property(lambda s:s.uow.locate_dirty(),
                      doc="A ``Set`` of all objects marked as 'dirty' within this ``Session``")
 
@@ -1047,10 +1044,6 @@ class Session(object):
     new = property(lambda s:s.uow.new,
                    doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
 
-    identity_map = property(lambda s:s.uow.identity_map,
-                            doc="A dictionary consisting of all objects "
-                            "within this ``Session`` keyed to their `_instance_key` value.")
-
     def import_instance(self, *args, **kwargs):
         """A synynom for ``merge()``."""
 
index fb2707d4dce75925f230e127d1d064181f280945..59fa2891ee2f57984821ea3d965c9b3153061ade 100644 (file)
@@ -19,30 +19,30 @@ class EagerTest(AssertMixin):
             false = bp(false)
         
         owners = Table ( 'owners', dbmeta ,
-               Column ( 'id', Integer, primary_key=True, nullable=False ),
-               Column('data', String(30)) )
+            Column ( 'id', Integer, primary_key=True, nullable=False ),
+            Column('data', String(30)) )
         categories=Table( 'categories', dbmeta,
-               Column ( 'id', Integer,primary_key=True, nullable=False ),
-               Column ( 'name', VARCHAR(20), index=True ) )
+            Column ( 'id', Integer,primary_key=True, nullable=False ),
+            Column ( 'name', VARCHAR(20), index=True ) )
         tests = Table ( 'tests', dbmeta ,
-               Column ( 'id', Integer, primary_key=True, nullable=False ),
-               Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ),
-               Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True ))
+            Column ( 'id', Integer, primary_key=True, nullable=False ),
+            Column ( 'owner_id',Integer, ForeignKey('owners.id'), nullable=False,index=True ),
+            Column ( 'category_id', Integer, ForeignKey('categories.id'),nullable=False,index=True ))
         options = Table ( 'options', dbmeta ,
-               Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ),
-               Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ),
-               Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) )
+            Column ( 'test_id', Integer, ForeignKey ( 'tests.id' ), primary_key=True, nullable=False ),
+            Column ( 'owner_id', Integer, ForeignKey ( 'owners.id' ), primary_key=True, nullable=False ),
+            Column ( 'someoption', Boolean, PassiveDefault(str(false)), nullable=False ) )
 
         dbmeta.create_all()
 
         class Owner(object):
-               pass
+            pass
         class Category(object):
-               pass
+            pass
         class Test(object):
-               pass
+            pass
         class Option(object):
-               pass
+            pass
         mapper(Owner,owners)
         mapper(Category,categories)
         mapper(Option,options,properties={'owner':relation(Owner),'test':relation(Test)})
@@ -66,17 +66,17 @@ class EagerTest(AssertMixin):
         s.save(c)
 
         for i in range(3):
-               t=Test()
-               t.owner=o
-               t.category=c
-               s.save(t)
-               if i==1:
-                       op=Option()
-                       op.someoption=True
-                       t.owner_option=op
-               if i==2:
-                       op=Option()
-                       t.owner_option=op
+            t=Test()
+            t.owner=o
+            t.category=c
+            s.save(t)
+            if i==1:
+                op=Option()
+                op.someoption=True
+                t.owner_option=op
+            if i==2:
+                op=Option()
+                t.owner_option=op
 
         s.flush()
         s.close()
@@ -99,16 +99,16 @@ class EagerTest(AssertMixin):
         # not orm style correct query
         print "Obtaining correct results without orm"
         result = select( [tests.c.id,categories.c.name],
-               and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
-               order_by=[tests.c.id],
-               from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall()
+            and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
+            order_by=[tests.c.id],
+            from_obj=[tests.join(categories).outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))] ).execute().fetchall()
         print result
         assert result == [(1, u'Some Category'), (3, u'Some Category')]
     
     def test_withouteagerload(self):
         s = create_session()
         l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
-               from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
+            from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
         result = ["%d %s" % ( t.id,t.category.name ) for t in l]
         print result
         assert result == [u'1 Some Category', u'3 Some Category']
@@ -119,7 +119,7 @@ class EagerTest(AssertMixin):
         s = create_session()
         q=s.query(Test).options(eagerload('category'))
         l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
-               from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
+            from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
         result = ["%d %s" % ( t.id,t.category.name ) for t in l]
         print result
         assert result == [u'1 Some Category', u'3 Some Category']
@@ -140,7 +140,7 @@ class EagerTest(AssertMixin):
         s = create_session()
         q=s.query(Test).options(eagerload('category'))
         l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') )
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]       
+        result = ["%d %s" % ( t.id,t.category.name ) for t in l]    
         print result
         assert result == [u'3 Some Category']
 
@@ -148,7 +148,7 @@ class EagerTest(AssertMixin):
         s = create_session()
         q=s.query(Test).options(eagerload('category'))
         l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') )
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]       
+        result = ["%d %s" % ( t.id,t.category.name ) for t in l]    
         print result
         assert result == [u'3 Some Category']
 
@@ -421,23 +421,23 @@ class EagerTest6(ORMTest):
     def define_tables(self, metadata):
         global designType, design, part, inheritedPart
         designType = Table('design_types', metadata, 
-               Column('design_type_id', Integer, primary_key=True),
-               )
+            Column('design_type_id', Integer, primary_key=True),
+            )
 
         design =Table('design', metadata, 
-               Column('design_id', Integer, primary_key=True),
-               Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+            Column('design_id', Integer, primary_key=True),
+            Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
 
         part = Table('parts', metadata, 
-               Column('part_id', Integer, primary_key=True),
-               Column('design_id', Integer, ForeignKey('design.design_id')),
-               Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
+            Column('part_id', Integer, primary_key=True),
+            Column('design_id', Integer, ForeignKey('design.design_id')),
+            Column('design_type_id', Integer, ForeignKey('design_types.design_type_id')))
 
         inheritedPart = Table('inherited_part', metadata,
-               Column('ip_id', Integer, primary_key=True),
-               Column('part_id', Integer, ForeignKey('parts.part_id')),
-               Column('design_id', Integer, ForeignKey('design.design_id')),
-               )
+            Column('ip_id', Integer, primary_key=True),
+            Column('part_id', Integer, ForeignKey('parts.part_id')),
+            Column('design_id', Integer, ForeignKey('design.design_id')),
+            )
 
     def testone(self):
         class Part(object):pass
@@ -448,16 +448,16 @@ class EagerTest6(ORMTest):
         mapper(Part, part)
 
         mapper(InheritedPart, inheritedPart, properties=dict(
-               part=relation(Part, lazy=False)
+            part=relation(Part, lazy=False)
         ))
 
         mapper(Design, design, properties=dict(
-               parts=relation(Part, private=True, backref="design"),
-               inheritedParts=relation(InheritedPart, private=True, backref="design"),
+            parts=relation(Part, private=True, backref="design"),
+            inheritedParts=relation(InheritedPart, private=True, backref="design"),
         ))
 
         mapper(DesignType, designType, properties=dict(
-        #      designs=relation(Design, private=True, backref="type"),
+        #   designs=relation(Design, private=True, backref="type"),
         ))
 
         class_mapper(Design).add_property("type", relation(DesignType, lazy=False, backref="designs"))
index 3cca2f7f1c9e501324c6265280a06913f9ea97d9..0c824d372d5a09544c1062f317ac43c43d515d4d 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import *
 from sqlalchemy import ansisql
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 
 from query import QueryTest
 
index dffc5322c8a981fa144cfd87c6faef4cd31ccfca..7751e6372a03eb0dadf9d3e2084fcd3fc57878d5 100644 (file)
@@ -4,7 +4,7 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 from query import QueryTest
 
 class EagerTest(QueryTest):
index f5c2bb0542b92a90ef26da48ceb84cf6ddb803bf..a38e548741536386ff8e7a647d1b77df440e42dc 100644 (file)
@@ -2,54 +2,8 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
+from testlib.fixtures import Base
 
-# TODO: refactor "fixtures" to be part of testlib, so Base is globally available
-_recursion_stack = util.Set()
-class Base(object):
-    def __init__(self, **kwargs):
-        for k in kwargs:
-            setattr(self, k, kwargs[k])
-    
-    def __ne__(self, other):
-        return not self.__eq__(other)
-        
-    def __eq__(self, other):
-        """'passively' compare this object to another.
-        
-        only look at attributes that are present on the source object.
-        
-        """
-        if self in _recursion_stack:
-            return True
-        _recursion_stack.add(self)
-        try:
-            # use __dict__ to avoid instrumented properties
-            for attr in self.__dict__.keys():
-                if attr[0] == '_':
-                    continue
-                value = getattr(self, attr)
-                if hasattr(value, '__iter__') and not isinstance(value, basestring):
-                    try:
-                        # catch AttributeError so that lazy loaders trigger
-                        otherattr = getattr(other, attr)
-                    except AttributeError:
-                        return False
-                    if len(value) != len(getattr(other, attr)):
-                       return False
-                    for (us, them) in zip(value, getattr(other, attr)):
-                        if us != them:
-                            return False
-                    else:
-                        continue
-                else:
-                    if value is not None:
-                        print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None)
-                        if value != getattr(other, attr, None):
-                            return False
-            else:
-                return True
-        finally:
-            _recursion_stack.remove(self)
 
 class InheritingSelectablesTest(ORMTest):
     def define_tables(self, metadata):
index 4fe3b354ac89f732ddeb83168754f0ef0bf4ea4f..e4a9c0c19f3040ec93a0077c96afdcc8f4026700 100644 (file)
@@ -4,7 +4,7 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 from query import QueryTest
 
 class LazyTest(QueryTest):
index 9a7a437da975e157e0e21f6cded9a754d0acd90f..e0b8bf4f3f7911e7363d6f46bc3869aa7d5a2bce 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import *
 from sqlalchemy import ansisql
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 
 class QueryTest(FixtureTest):
     keep_mappers = True
index 920cd9d8f389375d417676a90cfc4ab33db559a1..7f2fe0b6d273c790b3600367fbb5f928bc650634 100644 (file)
@@ -4,7 +4,7 @@ import testbase
 from sqlalchemy import *
 from sqlalchemy.orm import *
 from testlib import *
-from fixtures import *
+from testlib.fixtures import *
 from query import QueryTest
 
 class SelectableNoFromsTest(ORMTest):
index 7857bf0647e0ca6ab5552ac25c24bef423a3b8c3..337acd51daa53922931d28119136d24b28d1339c 100644 (file)
@@ -5,17 +5,20 @@ from sqlalchemy.orm.session import Session as SessionCls
 from testlib import *
 from testlib.tables import *
 import testlib.tables as tables
-import fixtures
+from testlib import fixtures
 
 class SessionTest(AssertMixin):
     def setUpAll(self):
         tables.create()
+        
     def tearDownAll(self):
         tables.drop()
+        
     def tearDown(self):
         SessionCls.close_all()
         tables.delete()
         clear_mappers()
+        
     def setUp(self):
         pass
 
@@ -614,7 +617,46 @@ class ScopedMapperTest(PersistTest):
         Session.mapper(MyClass, table2)
 
         assert MyClass().expunge() == "an expunge !"
+
+class ScopedMapperTest2(ORMTest):
+    def define_tables(self, metadata):
+        global table, table2
+        table = Table('sometable', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)),
+            Column('type', String(30))
+            
+            )
+        table2 = Table('someothertable', metadata, 
+            Column('id', Integer, primary_key=True),
+            Column('someid', None, ForeignKey('sometable.id')),
+            Column('somedata', String(30)),
+            )
     
+    def test_inheritance(self):
+        def expunge_list(l):
+            for x in l:
+                Session.expunge(x)
+            return l
+            
+        class BaseClass(fixtures.Base):
+            pass
+        class SubClass(BaseClass):
+            pass
+        
+        Session = scoped_session(sessionmaker())
+        Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
+        Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
+        
+        b = BaseClass(data='b1')
+        s =  SubClass(data='s1', somedata='somedata')
+        Session.commit()
+        Session.clear()
+        
+        assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
+        assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+        
+        
 
 if __name__ == "__main__":    
     testbase.main()
index df678cf099a1779be43938adf788acab73712758..90928dfd8e9c47d9a6a1a39283ebbef6d103b461 100644 (file)
@@ -6,7 +6,7 @@ from sqlalchemy.orm import *
 from testlib import *
 from testlib.tables import *
 from testlib import tables
-import fixtures
+from testlib import fixtures
 
 """tests unitofwork operations"""
 
similarity index 95%
rename from test/orm/fixtures.py
rename to test/testlib/fixtures.py
index ead4bc95105ee6d82a3d038f56a2bb1a22b9e91d..ab44838e7e83fffd96567aa3e901b3fc5db8f737 100644 (file)
@@ -24,6 +24,7 @@ class Base(object):
         only look at attributes that are present on the source object.
         
         """
+
         if self in _recursion_stack:
             return True
         _recursion_stack.add(self)
@@ -38,18 +39,21 @@ class Base(object):
                         # catch AttributeError so that lazy loaders trigger
                         otherattr = getattr(other, attr)
                     except AttributeError:
+                        print "Other class does not have attribute named '%s'" % attr
                         return False
                     if len(value) != len(getattr(other, attr)):
-                       return False
+                        print "Length of collection '%s' does not match that of other" % attr
+                        return False
                     for (us, them) in zip(value, getattr(other, attr)):
                         if us != them:
+                            print "1. Attribute named '%s' does not match other" % attr
                             return False
                     else:
                         continue
                 else:
                     if value is not None:
-                        print "KEY", attr, "COMPARING", value, "TO", getattr(other, attr, None)
                         if value != getattr(other, attr, None):
+                            print "2. Attribute named '%s' does not match that of other" % attr
                             return False
             else:
                 return True