]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- restored MapperExtension functionality for [ticket:829], added test coverage
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Oct 2007 16:12:29 +0000 (16:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 29 Oct 2007 16:12:29 +0000 (16:12 +0000)
- changed naming convention in mapper.py tests to test_<testname>

lib/sqlalchemy/orm/util.py
test/orm/mapper.py

index f4294502b1a2fcd98565f43052bb6a1de9c87809..9e3f20257fe1884d56ce5e6f41e0bb75ea6dcb5c 100644 (file)
@@ -122,8 +122,11 @@ class ExtensionCarrier(object):
     """
     
     def __init__(self, _elements=None):
-        self.__elements = _elements or []
         self.methods = {}
+        if _elements is not None:
+            self.__elements = [self.__inspect(e) for e in _elements]
+        else:
+            self.__elements = []
         
     def copy(self):
         return ExtensionCarrier(list(self.__elements))
index 9f11027af9e31b9d622988a0b5f05c544baeeb34..c9729944af65d6a09d549922af227c34e35df5a2 100644 (file)
@@ -23,7 +23,7 @@ class MapperSuperTest(AssertMixin):
     
 class MapperTest(MapperSuperTest):
 
-    def testpropconflict(self):
+    def test_propconflict(self):
         """test that a backref created against an existing mapper with a property name
         conflict raises a decent error message"""
         mapper(Address, addresses)
@@ -38,7 +38,7 @@ class MapperTest(MapperSuperTest):
         except exceptions.ArgumentError:
             pass
 
-    def testbadcascade(self):
+    def test_badcascade(self):
         mapper(Address, addresses)
         try:
             mapper(User, users, properties={'addresses':relation(Address, cascade="fake, all, delete-orphan")})
@@ -46,7 +46,7 @@ class MapperTest(MapperSuperTest):
         except exceptions.ArgumentError, e:
             assert str(e) == "Invalid cascade option 'fake'"
         
-    def testcolumnprefix(self):
+    def test_columnprefix(self):
         mapper(User, users, column_prefix='_', properties={ 
             'user_name':synonym('_user_name') 
         })
@@ -59,7 +59,7 @@ class MapperTest(MapperSuperTest):
         u2 = s.query(User).filter_by(user_name='jack').one() 
         assert u is u2
         
-    def testrefresh(self):
+    def test_refresh(self):
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')})
         s = create_session()
         u = s.get(User, 7)
@@ -92,12 +92,12 @@ class MapperTest(MapperSuperTest):
         self.assert_(u.user_name == 'jack')
         self.assert_(a not in u.addresses)
 
-    def testcompileonsession(self):
+    def test_compileonsession(self):
         m = mapper(User, users)
         session = create_session()
         session.connection(m)        
 
-    def testexpirecascade(self):
+    def test_expirecascade(self):
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")})
         s = create_session()
         u = s.get(User, 8)
@@ -105,7 +105,7 @@ class MapperTest(MapperSuperTest):
         s.expire(u)
         assert u.addresses[0].email_address == 'ed@wood.com'
         
-    def testrefreshwitheager(self):
+    def test_refreshwitheager(self):
         """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders"""
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
         s = create_session()
@@ -120,7 +120,7 @@ class MapperTest(MapperSuperTest):
         s.expire(u)
         assert len(u.addresses) == 3
     
-    def testincompletecolumns(self):
+    def test_incompletecolumns(self):
         """test loading from a select which does not contain all columns"""
         mapper(Address, addresses)
         s = create_session()
@@ -129,7 +129,7 @@ class MapperTest(MapperSuperTest):
         assert a.address_id == 1
         assert a.email_address is None
         
-    def testbadconstructor(self):
+    def test_badconstructor(self):
         """test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
         class Foo(object):
             def __init__(self, one, two):
@@ -147,7 +147,7 @@ class MapperTest(MapperSuperTest):
         except TypeError, e:
             pass
 
-    def testconstructorexceptions(self):
+    def test_constructorexceptions(self):
         """test that exceptions raised in the mapped class are not masked by sa decorations""" 
         ex = AssertionError('oops')
         sess = create_session()
@@ -178,7 +178,7 @@ class MapperTest(MapperSuperTest):
         except Exception, e:
             assert e is ex
             
-    def testrefresh_lazy(self):
+    def test_refresh_lazy(self):
         """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems"""
         s = create_session()
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))})
@@ -188,7 +188,7 @@ class MapperTest(MapperSuperTest):
             s.refresh(u)
         self.assert_sql_count(testbase.db, go, 1)
 
-    def testexpire(self):
+    def test_expire(self):
         """test the expire function"""
         s = create_session()
         mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)})
@@ -219,7 +219,7 @@ class MapperTest(MapperSuperTest):
         # this should *not* produce a SELECT statement (not tested here though....)
         self.assert_(u.user_name =='jack')
         
-    def testrefresh2(self):
+    def test_refresh2(self):
         """test a hang condition that was occuring on expire/refresh"""
         
         s = create_session()
@@ -242,13 +242,13 @@ class MapperTest(MapperSuperTest):
 
         s.refresh(u) #hangs
         
-    def testprops(self):
+    def test_props(self):
         m = mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
         }).compile()
         self.assert_(User.addresses.property is m.get_property('addresses'))
     
-    def testcompileonprop(self):
+    def test_compileonprop(self):
         mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
         })
@@ -267,7 +267,7 @@ class MapperTest(MapperSuperTest):
         mapper(Foo, addresses, inherits=User)
         assert getattr(Foo().__class__, 'user_name').impl is not None
     
-    def testaddproperty(self):
+    def test_addproperty(self):
         m = mapper(User, users)
         mapper(Address, addresses)
         m.add_property('user_name', deferred(users.c.user_name))
@@ -291,7 +291,7 @@ class MapperTest(MapperSuperTest):
         sess.flush()
         sess.rollback()
         
-    def testpropfilters(self):
+    def test_propfilters(self):
         t = Table('person', MetaData(),
                   Column('id', Integer, primary_key=True),
                   Column('type', String),
@@ -342,7 +342,7 @@ class MapperTest(MapperSuperTest):
         assert_props(Hoho, ['id', 'name', 'type'])
         assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type'])
 
-    def testrecursiveselectby(self):
+    def test_recursiveselectby(self):
         """test that no endless loop occurs when traversing for select_by"""
         m = mapper(User, users, properties={
             'orders':relation(mapper(Order, orders), backref='user'),
@@ -351,7 +351,7 @@ class MapperTest(MapperSuperTest):
         q = create_session().query(m)
         q.select_by(email_address='foo')
 
-    def testmappingtojoin(self):
+    def test_mappingtojoin(self):
         """test mapping to a join"""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
         m = mapper(User, usersaddresses, primary_key=[users.c.user_id])
@@ -359,7 +359,7 @@ class MapperTest(MapperSuperTest):
         l = q.select()
         self.assert_result(l, User, *user_result[0:2])
     
-    def testmappingtojoinnopk(self):
+    def test_mappingtojoinnopk(self):
         metadata = MetaData()
         account_ids_table = Table('account_ids', metadata,
                 Column('account_id', Integer, primary_key=True),
@@ -383,7 +383,7 @@ class MapperTest(MapperSuperTest):
         finally:
             metadata.drop_all(testbase.db)
         
-    def testmappingtoouterjoin(self):
+    def test_mappingtoouterjoin(self):
         """test mapping to an outer join, with a composite primary key that allows nulls"""
         result = [
         {'user_id' : 7, 'address_id' : 1},
@@ -400,7 +400,7 @@ class MapperTest(MapperSuperTest):
         self.assert_result(l, User, *result)
 
         
-    def testcustomjoin(self):
+    def test_customjoin(self):
         """test that the from_obj parameter to query.select() can be used
         to totally replace the FROM parameters of the generated query."""
 
@@ -414,7 +414,7 @@ class MapperTest(MapperSuperTest):
         l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)])
         self.assert_result(l, User, user_result[0])
             
-    def testorderby(self):
+    def test_orderby(self):
         """test ordering at the mapper and query level"""
         # TODO: make a unit test out of these various combinations
 #        m = mapper(User, users, order_by=desc(users.c.user_name))
@@ -428,7 +428,7 @@ class MapperTest(MapperSuperTest):
         
         
     @testing.unsupported('firebird') 
-    def testfunction(self):
+    def test_function(self):
         """test mapping to a SELECT statement that has functions in it."""
         s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')],
         users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect')
@@ -441,7 +441,7 @@ class MapperTest(MapperSuperTest):
         assert l[1].concat == l[1].user_id * 2 == 16
 
     @testing.unsupported('firebird') 
-    def testcount(self):
+    def test_count(self):
         """test the count function on Query.
         
         (why doesnt this work on firebird?)"""
@@ -451,14 +451,14 @@ class MapperTest(MapperSuperTest):
         self.assert_(q.count(users.c.user_id.in_([8,9]))==2)
         self.assert_(q.count_by(user_name='fred')==1)
 
-    def testmanytomany_count(self):
+    def test_manytomany_count(self):
         mapper(Item, orderitems, properties = dict(
                 keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
             ))
         q = create_session().query(Item)
         assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2
 
-    def testoverride(self):
+    def test_override(self):
         # assert that overriding a column raises an error
         try:
             m = mapper(User, users, properties = {
@@ -481,7 +481,7 @@ class MapperTest(MapperSuperTest):
                 'foo' : users.c.user_name,
             })
 
-    def testsynonym(self):
+    def test_synonym(self):
         sess = create_session()
         mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True),
@@ -510,7 +510,7 @@ class MapperTest(MapperSuperTest):
         assert u.user_name == "some user name"
         assert u in sess.dirty
 
-    def testsynonymoptions(self):
+    def test_synonymoptions(self):
         sess = create_session()
         mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True),
@@ -522,7 +522,7 @@ class MapperTest(MapperSuperTest):
             self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
         self.assert_sql_count(testbase.db, go, 1)
         
-    def testextensionoptions(self):
+    def test_extensionoptions(self):
         sess  = create_session()
         class ext1(MapperExtension):
             def populate_instance(self, mapper, selectcontext, row, instance, **flags):
@@ -549,7 +549,7 @@ class MapperTest(MapperSuperTest):
         assert not hasattr(l.addresses[0], 'TEST')
         assert not hasattr(l.addresses[0], 'TEST2')
         
-    def testeageroptions(self):
+    def test_eageroptions(self):
         """tests that a lazy relation can be upgraded to an eager relation via the options method"""
         sess = create_session()
         mapper(User, users, properties = dict(
@@ -561,7 +561,7 @@ class MapperTest(MapperSuperTest):
             self.assert_result(l, User, *user_address_result)
         self.assert_sql_count(testbase.db, go, 0)
 
-    def testeageroptionswithlimit(self):
+    def test_eageroptionswithlimit(self):
         sess = create_session()
         mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = True)
@@ -582,7 +582,7 @@ class MapperTest(MapperSuperTest):
             assert len(u.addresses) == 3
         assert "tbl_row_count" not in self.capture_sql(testbase.db, go)
         
-    def testlazyoptionswithlimit(self):
+    def test_lazyoptionswithlimit(self):
         sess = create_session()
         mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy = False)
@@ -594,7 +594,7 @@ class MapperTest(MapperSuperTest):
             assert len(u.addresses) == 3
         self.assert_sql_count(testbase.db, go, 1)
 
-    def testeagerdegrade(self):
+    def test_eagerdegrade(self):
         """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available"""
         sess = create_session()
         usermapper = mapper(User, users, properties = dict(
@@ -656,7 +656,7 @@ class MapperTest(MapperSuperTest):
         self.assert_sql_count(testbase.db, go, 7)
         
         
-    def testlazyoptions(self):
+    def test_lazyoptions(self):
         """tests that an eager relation can be upgraded to a lazy relation via the options method"""
         sess = create_session()
         mapper(User, users, properties = dict(
@@ -667,7 +667,7 @@ class MapperTest(MapperSuperTest):
             self.assert_result(l, User, *user_address_result)
         self.assert_sql_count(testbase.db, go, 3)
 
-    def testlatecompile(self):
+    def test_latecompile(self):
         """tests mappers compiling late in the game"""
         
         mapper(User, users, properties = {'orders': relation(Order)})
@@ -681,7 +681,7 @@ class MapperTest(MapperSuperTest):
             print u[0].orders[1].items[0].keywords[1]
         self.assert_sql_count(testbase.db, go, 3)
 
-    def testdeepoptions(self):
+    def test_deepoptions(self):
         mapper(User, users,
             properties = {
                 'orders': relation(mapper(Order, orders, properties = {
@@ -733,7 +733,7 @@ class MapperTest(MapperSuperTest):
     
 class DeferredTest(MapperSuperTest):
 
-    def testbasic(self):
+    def test_basic(self):
         """tests a basic "deferred" load"""
         
         m = mapper(Order, orders, properties={
@@ -755,7 +755,7 @@ class DeferredTest(MapperSuperTest):
             ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :param_1", {'param_1':3})
         ])
 
-    def testunsaved(self):
+    def test_unsaved(self):
         """test that deferred loading doesnt kick in when just PK cols are set"""
         m = mapper(Order, orders, properties={
             'description':deferred(orders.c.description)
@@ -769,7 +769,7 @@ class DeferredTest(MapperSuperTest):
             o.description = "some description"
         self.assert_sql_count(testbase.db, go, 0)
 
-    def testunsavedgroup(self):
+    def test_unsavedgroup(self):
         """test that deferred loading doesnt kick in when just PK cols are set"""
         m = mapper(Order, orders, properties={
             'description':deferred(orders.c.description, group='primary'),
@@ -784,7 +784,7 @@ class DeferredTest(MapperSuperTest):
             o.description = "some description"
         self.assert_sql_count(testbase.db, go, 0)
         
-    def testsave(self):
+    def test_save(self):
         m = mapper(Order, orders, properties={
             'description':deferred(orders.c.description)
         })
@@ -796,7 +796,7 @@ class DeferredTest(MapperSuperTest):
         o2.isopen = 1
         sess.flush()
         
-    def testgroup(self):
+    def test_group(self):
         """tests deferred load with a group"""
         m = mapper(Order, orders, properties = {
             'userident':deferred(orders.c.user_id, group='primary'),
@@ -827,7 +827,7 @@ class DeferredTest(MapperSuperTest):
             sess.flush()
         self.assert_sql_count(testbase.db, go, 0)
     
-    def testcommitsstate(self):
+    def test_commitsstate(self):
         """test that when deferred elements are loaded via a group, they get the proper CommittedState
         and dont result in changes being committed"""
         
@@ -849,7 +849,7 @@ class DeferredTest(MapperSuperTest):
             sess.flush()
         self.assert_sql_count(testbase.db, go, 0)
             
-    def testoptions(self):
+    def test_options(self):
         """tests using options on a mapper to create deferred and undeferred columns"""
         m = mapper(Order, orders)
         sess = create_session()
@@ -873,7 +873,7 @@ class DeferredTest(MapperSuperTest):
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
         ])
 
-    def testundefergroup(self):
+    def test_undefergroup(self):
         """tests undefer_group()"""
         m = mapper(Order, orders, properties = {
             'userident':deferred(orders.c.user_id, group='primary'),
@@ -895,7 +895,7 @@ class DeferredTest(MapperSuperTest):
         ])
 
         
-    def testdeepoptions(self):
+    def test_deepoptions(self):
         m = mapper(User, users, properties={
             'orders':relation(mapper(Order, orders, properties={
                 'items':relation(mapper(Item, orderitems, properties={
@@ -1051,7 +1051,7 @@ class CompositeTypesTest(ORMTest):
         
         
 class NoLoadTest(MapperSuperTest):
-    def testbasic(self):
+    def test_basic(self):
         """tests a basic one-to-many lazy load"""
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy=None)
@@ -1067,7 +1067,7 @@ class NoLoadTest(MapperSuperTest):
         self.assert_result(l[0], User,
             {'user_id' : 7, 'addresses' : (Address, [])},
             )
-    def testoptions(self):
+    def test_options(self):
         m = mapper(User, users, properties = dict(
             addresses = relation(mapper(Address, addresses), lazy=None)
         ))
@@ -1084,7 +1084,17 @@ class NoLoadTest(MapperSuperTest):
             )
 
 class MapperExtensionTest(MapperSuperTest):
-    def testcreateinstance(self):
+    def setUpAll(self):
+        tables.create()
+    def tearDownAll(self):
+        tables.drop()
+    def tearDown(self):
+        clear_mappers()
+        tables.delete()
+    def setUp(self):
+        tables.data()
+
+    def test_create_instance(self):
         class Ext(MapperExtension):
             def create_instance(self, *args, **kwargs):
                 return User()
@@ -1097,6 +1107,74 @@ class MapperExtensionTest(MapperSuperTest):
         l = q.select();
         self.assert_result(l, User, *user_address_result)
     
+    def test_methods(self):
+        """test that common user-defined methods get called."""
+        
+        methods = set()
+        class Ext(MapperExtension):
+            def load(self, query, *args, **kwargs):
+                methods.add('load')
+                return EXT_CONTINUE
+
+            def get(self, query, *args, **kwargs):
+                methods.add('get')
+                return EXT_CONTINUE
+
+            def translate_row(self, mapper, context, row):
+                methods.add('translate_row')
+                return EXT_CONTINUE
+
+            def create_instance(self, mapper, selectcontext, row, class_):
+                methods.add('create_instance')
+                return EXT_CONTINUE
+
+            def append_result(self, mapper, selectcontext, row, instance, result, **flags):
+                methods.add('append_result')
+                return EXT_CONTINUE
+
+            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
+                methods.add('populate_instance')
+                return EXT_CONTINUE
+
+            def before_insert(self, mapper, connection, instance):
+                methods.add('before_insert')
+                return EXT_CONTINUE
+
+            def after_insert(self, mapper, connection, instance):
+                methods.add('after_insert')
+                return EXT_CONTINUE
+
+            def before_update(self, mapper, connection, instance):
+                methods.add('before_update')
+                return EXT_CONTINUE
+
+            def after_update(self, mapper, connection, instance):
+                methods.add('after_update')
+                return EXT_CONTINUE
+
+            def before_delete(self, mapper, connection, instance):
+                methods.add('before_delete')
+                return EXT_CONTINUE
+
+            def after_delete(self, mapper, connection, instance):
+                methods.add('after_delete')
+                return EXT_CONTINUE
+        
+        mapper(User, users, extension=Ext())
+        sess = create_session()
+        u = User()
+        sess.save(u)
+        sess.flush()
+        u = sess.query(User).load(u.user_id)
+        sess.clear()
+        u = sess.query(User).get(u.user_id)
+        u.user_name = 'foobar'
+        sess.flush()
+        sess.delete(u)
+        sess.flush()
+        assert methods == set(['load', 'append_result', 'before_delete', 'create_instance', 'translate_row', 'get', 
+                'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance'])
+        
 
 if __name__ == "__main__":    
     testbase.main()