]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix up the fixtures comparator
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Dec 2007 15:56:37 +0000 (15:56 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Dec 2007 15:56:37 +0000 (15:56 +0000)
- strengthened memory profiling test

test/orm/memusage.py
test/testlib/fixtures.py

index 26da7c010d6f29d828a494a1f3ac03def0539869..e8d98aa426e52f6cf945bde6975f2f180df9b2ee 100644 (file)
@@ -4,32 +4,33 @@ from sqlalchemy import MetaData, Integer, String, ForeignKey
 from sqlalchemy.orm import mapper, relation, clear_mappers, create_session
 from sqlalchemy.orm.mapper import Mapper
 from testlib import *
+from testlib.fixtures import Base
 
-class A(object):pass
-class B(object):pass
+class A(Base):pass
+class B(Base):pass
 
-class MapperCleanoutTest(AssertMixin):
-    """test that clear_mappers() removes everything related to the class.
-    
-    does not include classes that use the assignmapper extension."""
-
-    def test_mapper_cleanup(self):
-        for x in range(0, 5):
-            self.do_test()
+def profile_memory(func):
+    # run the test 50 times.  if length of gc.get_objects()
+    # keeps growing, assert false
+    def profile(*args):
+        samples = []
+        for x in range(0, 50):
+            func(*args)
             gc.collect()
-            for o in gc.get_objects():
-                if isinstance(o, Mapper):
-                    # the classes in the 'tables' package have assign_mapper called on them
-                    # which is particularly sticky
-                    # if getattr(tables, o.class_.__name__, None) is o.class_:
-                    #    continue
-                    # well really we are just testing our own classes here
-                    if (o.class_ not in [A,B]):
-                        continue
-                    assert False
+            samples.append(len(gc.get_objects()))
+        print "sample gc sizes:", samples
+        # TODO: this test only finds pure "growing" tests
+        for i, x in enumerate(samples):
+            if i < len(samples) - 1 and samples[i+1] <= x:
+                break
+        else:
+            assert False
         assert True
-        
-    def do_test(self):
+    return profile
+
+class MemUsageTest(AssertMixin):
+    
+    def test_session(self):
         metadata = MetaData(testbase.db)
 
         table1 = Table("mytable", metadata, 
@@ -45,32 +46,99 @@ class MapperCleanoutTest(AssertMixin):
     
         metadata.create_all()
 
-
         m1 = mapper(A, table1, properties={
-            "bs":relation(B)
+            "bs":relation(B, cascade="all, delete")
         })
         m2 = mapper(B, table2)
 
         m3 = mapper(A, table1, non_primary=True)
         
-        sess = create_session()
-        a1 = A()
-        a2 = A()
-        a3 = A()
-        a1.bs.append(B())
-        a1.bs.append(B())
-        a3.bs.append(B())
-        for x in [a1,a2,a3]:
-            sess.save(x)
-        sess.flush()
-        sess.clear()
+        @profile_memory
+        def go():
+            sess = create_session()
+            a1 = A(col2="a1")
+            a2 = A(col2="a2")
+            a3 = A(col2="a3")
+            a1.bs.append(B(col2="b1"))
+            a1.bs.append(B(col2="b2"))
+            a3.bs.append(B(col2="b3"))
+            for x in [a1,a2,a3]:
+                sess.save(x)
+            sess.flush()
+            sess.clear()
 
-        alist = sess.query(A).select()
-        for a in alist:
-            print "A", a, "BS", [b for b in a.bs]
-    
+            alist = sess.query(A).all()
+            self.assertEquals(
+                [
+                    A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), 
+                    A(col2="a2", bs=[]), 
+                    A(col2="a3", bs=[B(col2="b3")])
+                ], 
+                alist)
+
+            for a in alist:
+                sess.delete(a)
+            sess.flush()
+        go()
+        
         metadata.drop_all()
         clear_mappers()
+        
+    def test_mapper_reset(self):
+        metadata = MetaData(testbase.db)
+
+        table1 = Table("mytable", metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', String(30))
+            )
+
+        table2 = Table("mytable2", metadata, 
+            Column('col1', Integer, primary_key=True),
+            Column('col2', String(30)),
+            Column('col3', Integer, ForeignKey("mytable.col1"))
+            )
+
+        @profile_memory
+        def go():
+            m1 = mapper(A, table1, properties={
+                "bs":relation(B)
+            })
+            m2 = mapper(B, table2)
+
+            m3 = mapper(A, table1, non_primary=True)
+        
+            sess = create_session()
+            a1 = A(col2="a1")
+            a2 = A(col2="a2")
+            a3 = A(col2="a3")
+            a1.bs.append(B(col2="b1"))
+            a1.bs.append(B(col2="b2"))
+            a3.bs.append(B(col2="b3"))
+            for x in [a1,a2,a3]:
+                sess.save(x)
+            sess.flush()
+            sess.clear()
+
+            alist = sess.query(A).all()
+            self.assertEquals(
+                [
+                    A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]), 
+                    A(col2="a2", bs=[]), 
+                    A(col2="a3", bs=[B(col2="b3")])
+                ], 
+                alist)
+        
+            for a in alist:
+                sess.delete(a)
+            sess.flush()
+            clear_mappers()
+        
+        metadata.create_all()
+        try:
+            go()
+        finally:
+            metadata.drop_all()
+
     
 if __name__ == '__main__':
     testbase.main()
index 022fce094349807b406147592d4ba2231eaa6788..4394780bb20f3cecde9f93fa06fa6b6b3233ca3b 100644 (file)
@@ -30,34 +30,46 @@ class Base(object):
             return True
         _recursion_stack.add(self)
         try:
-            # use __dict__ to avoid instrumented properties
-            for attr in self.__dict__.keys():
+            # pick the entity thats not SA persisted as the source
+            if other is None:
+                a = self
+                b = other
+            elif hasattr(self, '_instance_key'):
+                a = other
+                b = self
+            else:
+                a = self
+                b = other
+            
+            for attr in a.__dict__.keys():
                 if attr[0] == '_':
                     continue
-                value = getattr(self, attr)
+                value = getattr(a, attr)
+                #print "looking at attr:", attr, "start value:", value
                 if hasattr(value, '__iter__') and not isinstance(value, basestring):
                     try:
                         # catch AttributeError so that lazy loaders trigger
-                        otherattr = getattr(other, attr)
+                        battr = getattr(b, attr)
                     except AttributeError:
-                        #print "Other class does not have attribute named '%s'" % attr
+                        #print "b class does not have attribute named '%s'" % attr
                         return False
+                    #print "other:", battr
                     if not hasattr(value, '__len__'):
                         value = list(iter(value))
-                        otherattr = list(iter(otherattr))
-                    if len(value) != len(otherattr):
-                        #print "Length of collection '%s' does not match that of other" % attr
+                        battr = list(iter(battr))
+                    if len(value) != len(battr):
+                        #print "Length of collection '%s' does not match that of b" % attr
                         return False
-                    for (us, them) in zip(value, otherattr):
+                    for (us, them) in zip(value, battr):
                         if us != them:
-                            #print "1. Attribute named '%s' does not match other" % attr
+                            #print "1. Attribute named '%s' does not match b" % attr
                             return False
                     else:
                         continue
                 else:
                     if value is not None:
-                        if value != getattr(other, attr, None):
-                            #print "2. Attribute named '%s' does not match that of other" % attr
+                        if value != getattr(b, attr, None):
+                            #print "2. Attribute named '%s' does not match that of b" % attr
                             return False
             else:
                 return True