from sqlalchemy.orm import mapper, relation, clear_mappers, create_session
from sqlalchemy.orm.mapper import Mapper
from testlib import *
+from testlib.fixtures import Base
-class A(object):pass
-class B(object):pass
+class A(Base):pass
+class B(Base):pass
-class MapperCleanoutTest(AssertMixin):
- """test that clear_mappers() removes everything related to the class.
-
- does not include classes that use the assignmapper extension."""
-
- def test_mapper_cleanup(self):
- for x in range(0, 5):
- self.do_test()
+def profile_memory(func):
+ # run the test 50 times. if length of gc.get_objects()
+ # keeps growing, assert false
+ def profile(*args):
+ samples = []
+ for x in range(0, 50):
+ func(*args)
gc.collect()
- for o in gc.get_objects():
- if isinstance(o, Mapper):
- # the classes in the 'tables' package have assign_mapper called on them
- # which is particularly sticky
- # if getattr(tables, o.class_.__name__, None) is o.class_:
- # continue
- # well really we are just testing our own classes here
- if (o.class_ not in [A,B]):
- continue
- assert False
+ samples.append(len(gc.get_objects()))
+ print "sample gc sizes:", samples
+ # TODO: this test only finds pure "growing" tests
+ for i, x in enumerate(samples):
+ if i < len(samples) - 1 and samples[i+1] <= x:
+ break
+ else:
+ assert False
assert True
-
- def do_test(self):
+ return profile
+
+class MemUsageTest(AssertMixin):
+
+ def test_session(self):
metadata = MetaData(testbase.db)
table1 = Table("mytable", metadata,
metadata.create_all()
-
m1 = mapper(A, table1, properties={
- "bs":relation(B)
+ "bs":relation(B, cascade="all, delete")
})
m2 = mapper(B, table2)
m3 = mapper(A, table1, non_primary=True)
- sess = create_session()
- a1 = A()
- a2 = A()
- a3 = A()
- a1.bs.append(B())
- a1.bs.append(B())
- a3.bs.append(B())
- for x in [a1,a2,a3]:
- sess.save(x)
- sess.flush()
- sess.clear()
+ @profile_memory
+ def go():
+ sess = create_session()
+ a1 = A(col2="a1")
+ a2 = A(col2="a2")
+ a3 = A(col2="a3")
+ a1.bs.append(B(col2="b1"))
+ a1.bs.append(B(col2="b2"))
+ a3.bs.append(B(col2="b3"))
+ for x in [a1,a2,a3]:
+ sess.save(x)
+ sess.flush()
+ sess.clear()
- alist = sess.query(A).select()
- for a in alist:
- print "A", a, "BS", [b for b in a.bs]
-
+ alist = sess.query(A).all()
+ self.assertEquals(
+ [
+ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
+ A(col2="a2", bs=[]),
+ A(col2="a3", bs=[B(col2="b3")])
+ ],
+ alist)
+
+ for a in alist:
+ sess.delete(a)
+ sess.flush()
+ go()
+
metadata.drop_all()
clear_mappers()
+
+ def test_mapper_reset(self):
+ metadata = MetaData(testbase.db)
+
+ table1 = Table("mytable", metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30))
+ )
+
+ table2 = Table("mytable2", metadata,
+ Column('col1', Integer, primary_key=True),
+ Column('col2', String(30)),
+ Column('col3', Integer, ForeignKey("mytable.col1"))
+ )
+
+ @profile_memory
+ def go():
+ m1 = mapper(A, table1, properties={
+ "bs":relation(B)
+ })
+ m2 = mapper(B, table2)
+
+ m3 = mapper(A, table1, non_primary=True)
+
+ sess = create_session()
+ a1 = A(col2="a1")
+ a2 = A(col2="a2")
+ a3 = A(col2="a3")
+ a1.bs.append(B(col2="b1"))
+ a1.bs.append(B(col2="b2"))
+ a3.bs.append(B(col2="b3"))
+ for x in [a1,a2,a3]:
+ sess.save(x)
+ sess.flush()
+ sess.clear()
+
+ alist = sess.query(A).all()
+ self.assertEquals(
+ [
+ A(col2="a1", bs=[B(col2="b1"), B(col2="b2")]),
+ A(col2="a2", bs=[]),
+ A(col2="a3", bs=[B(col2="b3")])
+ ],
+ alist)
+
+ for a in alist:
+ sess.delete(a)
+ sess.flush()
+ clear_mappers()
+
+ metadata.create_all()
+ try:
+ go()
+ finally:
+ metadata.drop_all()
+
if __name__ == '__main__':
testbase.main()
return True
_recursion_stack.add(self)
try:
- # use __dict__ to avoid instrumented properties
- for attr in self.__dict__.keys():
+ # pick the entity thats not SA persisted as the source
+ if other is None:
+ a = self
+ b = other
+ elif hasattr(self, '_instance_key'):
+ a = other
+ b = self
+ else:
+ a = self
+ b = other
+
+ for attr in a.__dict__.keys():
if attr[0] == '_':
continue
- value = getattr(self, attr)
+ value = getattr(a, attr)
+ #print "looking at attr:", attr, "start value:", value
if hasattr(value, '__iter__') and not isinstance(value, basestring):
try:
# catch AttributeError so that lazy loaders trigger
- otherattr = getattr(other, attr)
+ battr = getattr(b, attr)
except AttributeError:
- #print "Other class does not have attribute named '%s'" % attr
+ #print "b class does not have attribute named '%s'" % attr
return False
+ #print "other:", battr
if not hasattr(value, '__len__'):
value = list(iter(value))
- otherattr = list(iter(otherattr))
- if len(value) != len(otherattr):
- #print "Length of collection '%s' does not match that of other" % attr
+ battr = list(iter(battr))
+ if len(value) != len(battr):
+ #print "Length of collection '%s' does not match that of b" % attr
return False
- for (us, them) in zip(value, otherattr):
+ for (us, them) in zip(value, battr):
if us != them:
- #print "1. Attribute named '%s' does not match other" % attr
+ #print "1. Attribute named '%s' does not match b" % attr
return False
else:
continue
else:
if value is not None:
- if value != getattr(other, attr, None):
- #print "2. Attribute named '%s' does not match that of other" % attr
+ if value != getattr(b, attr, None):
+ #print "2. Attribute named '%s' does not match that of b" % attr
return False
else:
return True