From: Mike Bayer Date: Mon, 29 Oct 2007 16:12:29 +0000 (+0000) Subject: - restored MapperExtension functionality for [ticket:829], added test coverage X-Git-Tag: rel_0_4_1~99 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=adcf8ea00dda6d8e62fceeab36d90eabe36b5f91;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - restored MapperExtension functionality for [ticket:829], added test coverage - changed naming convention in mapper.py tests to test_ --- diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index f4294502b1..9e3f20257f 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -122,8 +122,11 @@ class ExtensionCarrier(object): """ def __init__(self, _elements=None): - self.__elements = _elements or [] self.methods = {} + if _elements is not None: + self.__elements = [self.__inspect(e) for e in _elements] + else: + self.__elements = [] def copy(self): return ExtensionCarrier(list(self.__elements)) diff --git a/test/orm/mapper.py b/test/orm/mapper.py index 9f11027af9..c9729944af 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -23,7 +23,7 @@ class MapperSuperTest(AssertMixin): class MapperTest(MapperSuperTest): - def testpropconflict(self): + def test_propconflict(self): """test that a backref created against an existing mapper with a property name conflict raises a decent error message""" mapper(Address, addresses) @@ -38,7 +38,7 @@ class MapperTest(MapperSuperTest): except exceptions.ArgumentError: pass - def testbadcascade(self): + def test_badcascade(self): mapper(Address, addresses) try: mapper(User, users, properties={'addresses':relation(Address, cascade="fake, all, delete-orphan")}) @@ -46,7 +46,7 @@ class MapperTest(MapperSuperTest): except exceptions.ArgumentError, e: assert str(e) == "Invalid cascade option 'fake'" - def testcolumnprefix(self): + def test_columnprefix(self): mapper(User, users, column_prefix='_', properties={ 'user_name':synonym('_user_name') }) @@ -59,7 +59,7 @@ class MapperTest(MapperSuperTest): u2 = s.query(User).filter_by(user_name='jack').one() assert u is u2 - def testrefresh(self): + def test_refresh(self): mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), backref='user')}) s = create_session() u = s.get(User, 7) @@ -92,12 +92,12 @@ class MapperTest(MapperSuperTest): self.assert_(u.user_name == 'jack') self.assert_(a not in u.addresses) - def testcompileonsession(self): + def test_compileonsession(self): m = mapper(User, users) session = create_session() session.connection(m) - def testexpirecascade(self): + def test_expirecascade(self): mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), cascade="all, refresh-expire")}) s = create_session() u = s.get(User, 8) @@ -105,7 +105,7 @@ class MapperTest(MapperSuperTest): s.expire(u) assert u.addresses[0].email_address == 'ed@wood.com' - def testrefreshwitheager(self): + def test_refreshwitheager(self): """test that a refresh/expire operation loads rows properly and sends correct "isnew" state to eager loaders""" mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) s = create_session() @@ -120,7 +120,7 @@ class MapperTest(MapperSuperTest): s.expire(u) assert len(u.addresses) == 3 - def testincompletecolumns(self): + def test_incompletecolumns(self): """test loading from a select which does not contain all columns""" mapper(Address, addresses) s = create_session() @@ -129,7 +129,7 @@ class MapperTest(MapperSuperTest): assert a.address_id == 1 assert a.email_address is None - def testbadconstructor(self): + def test_badconstructor(self): """test that if the construction of a mapped class fails, the instnace does not get placed in the session""" class Foo(object): def __init__(self, one, two): @@ -147,7 +147,7 @@ class MapperTest(MapperSuperTest): except TypeError, e: pass - def testconstructorexceptions(self): + def test_constructorexceptions(self): """test that exceptions raised in the mapped class are not masked by sa decorations""" ex = AssertionError('oops') sess = create_session() @@ -178,7 +178,7 @@ class MapperTest(MapperSuperTest): except Exception, e: assert e is ex - def testrefresh_lazy(self): + def test_refresh_lazy(self): """test that when a lazy loader is set as a trigger on an object's attribute (at the attribute level, not the class level), a refresh() operation doesnt fire the lazy loader or create any problems""" s = create_session() mapper(User, users, properties={'addresses':relation(mapper(Address, addresses))}) @@ -188,7 +188,7 @@ class MapperTest(MapperSuperTest): s.refresh(u) self.assert_sql_count(testbase.db, go, 1) - def testexpire(self): + def test_expire(self): """test the expire function""" s = create_session() mapper(User, users, properties={'addresses':relation(mapper(Address, addresses), lazy=False)}) @@ -219,7 +219,7 @@ class MapperTest(MapperSuperTest): # this should *not* produce a SELECT statement (not tested here though....) self.assert_(u.user_name =='jack') - def testrefresh2(self): + def test_refresh2(self): """test a hang condition that was occuring on expire/refresh""" s = create_session() @@ -242,13 +242,13 @@ class MapperTest(MapperSuperTest): s.refresh(u) #hangs - def testprops(self): + def test_props(self): m = mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) }).compile() self.assert_(User.addresses.property is m.get_property('addresses')) - def testcompileonprop(self): + def test_compileonprop(self): mapper(User, users, properties = { 'addresses' : relation(mapper(Address, addresses)) }) @@ -267,7 +267,7 @@ class MapperTest(MapperSuperTest): mapper(Foo, addresses, inherits=User) assert getattr(Foo().__class__, 'user_name').impl is not None - def testaddproperty(self): + def test_addproperty(self): m = mapper(User, users) mapper(Address, addresses) m.add_property('user_name', deferred(users.c.user_name)) @@ -291,7 +291,7 @@ class MapperTest(MapperSuperTest): sess.flush() sess.rollback() - def testpropfilters(self): + def test_propfilters(self): t = Table('person', MetaData(), Column('id', Integer, primary_key=True), Column('type', String), @@ -342,7 +342,7 @@ class MapperTest(MapperSuperTest): assert_props(Hoho, ['id', 'name', 'type']) assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type']) - def testrecursiveselectby(self): + def test_recursiveselectby(self): """test that no endless loop occurs when traversing for select_by""" m = mapper(User, users, properties={ 'orders':relation(mapper(Order, orders), backref='user'), @@ -351,7 +351,7 @@ class MapperTest(MapperSuperTest): q = create_session().query(m) q.select_by(email_address='foo') - def testmappingtojoin(self): + def test_mappingtojoin(self): """test mapping to a join""" usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id) m = mapper(User, usersaddresses, primary_key=[users.c.user_id]) @@ -359,7 +359,7 @@ class MapperTest(MapperSuperTest): l = q.select() self.assert_result(l, User, *user_result[0:2]) - def testmappingtojoinnopk(self): + def test_mappingtojoinnopk(self): metadata = MetaData() account_ids_table = Table('account_ids', metadata, Column('account_id', Integer, primary_key=True), @@ -383,7 +383,7 @@ class MapperTest(MapperSuperTest): finally: metadata.drop_all(testbase.db) - def testmappingtoouterjoin(self): + def test_mappingtoouterjoin(self): """test mapping to an outer join, with a composite primary key that allows nulls""" result = [ {'user_id' : 7, 'address_id' : 1}, @@ -400,7 +400,7 @@ class MapperTest(MapperSuperTest): self.assert_result(l, User, *result) - def testcustomjoin(self): + def test_customjoin(self): """test that the from_obj parameter to query.select() can be used to totally replace the FROM parameters of the generated query.""" @@ -414,7 +414,7 @@ class MapperTest(MapperSuperTest): l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)]) self.assert_result(l, User, user_result[0]) - def testorderby(self): + def test_orderby(self): """test ordering at the mapper and query level""" # TODO: make a unit test out of these various combinations # m = mapper(User, users, order_by=desc(users.c.user_name)) @@ -428,7 +428,7 @@ class MapperTest(MapperSuperTest): @testing.unsupported('firebird') - def testfunction(self): + def test_function(self): """test mapping to a SELECT statement that has functions in it.""" s = select([users, (users.c.user_id * 2).label('concat'), func.count(addresses.c.address_id).label('count')], users.c.user_id==addresses.c.user_id, group_by=[c for c in users.c]).alias('myselect') @@ -441,7 +441,7 @@ class MapperTest(MapperSuperTest): assert l[1].concat == l[1].user_id * 2 == 16 @testing.unsupported('firebird') - def testcount(self): + def test_count(self): """test the count function on Query. (why doesnt this work on firebird?)""" @@ -451,14 +451,14 @@ class MapperTest(MapperSuperTest): self.assert_(q.count(users.c.user_id.in_([8,9]))==2) self.assert_(q.count_by(user_name='fred')==1) - def testmanytomany_count(self): + def test_manytomany_count(self): mapper(Item, orderitems, properties = dict( keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True), )) q = create_session().query(Item) assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2 - def testoverride(self): + def test_override(self): # assert that overriding a column raises an error try: m = mapper(User, users, properties = { @@ -481,7 +481,7 @@ class MapperTest(MapperSuperTest): 'foo' : users.c.user_name, }) - def testsynonym(self): + def test_synonym(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True), @@ -510,7 +510,7 @@ class MapperTest(MapperSuperTest): assert u.user_name == "some user name" assert u in sess.dirty - def testsynonymoptions(self): + def test_synonymoptions(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True), @@ -522,7 +522,7 @@ class MapperTest(MapperSuperTest): self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1])) self.assert_sql_count(testbase.db, go, 1) - def testextensionoptions(self): + def test_extensionoptions(self): sess = create_session() class ext1(MapperExtension): def populate_instance(self, mapper, selectcontext, row, instance, **flags): @@ -549,7 +549,7 @@ class MapperTest(MapperSuperTest): assert not hasattr(l.addresses[0], 'TEST') assert not hasattr(l.addresses[0], 'TEST2') - def testeageroptions(self): + def test_eageroptions(self): """tests that a lazy relation can be upgraded to an eager relation via the options method""" sess = create_session() mapper(User, users, properties = dict( @@ -561,7 +561,7 @@ class MapperTest(MapperSuperTest): self.assert_result(l, User, *user_address_result) self.assert_sql_count(testbase.db, go, 0) - def testeageroptionswithlimit(self): + def test_eageroptionswithlimit(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = True) @@ -582,7 +582,7 @@ class MapperTest(MapperSuperTest): assert len(u.addresses) == 3 assert "tbl_row_count" not in self.capture_sql(testbase.db, go) - def testlazyoptionswithlimit(self): + def test_lazyoptionswithlimit(self): sess = create_session() mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy = False) @@ -594,7 +594,7 @@ class MapperTest(MapperSuperTest): assert len(u.addresses) == 3 self.assert_sql_count(testbase.db, go, 1) - def testeagerdegrade(self): + def test_eagerdegrade(self): """tests that an eager relation automatically degrades to a lazy relation if eager columns are not available""" sess = create_session() usermapper = mapper(User, users, properties = dict( @@ -656,7 +656,7 @@ class MapperTest(MapperSuperTest): self.assert_sql_count(testbase.db, go, 7) - def testlazyoptions(self): + def test_lazyoptions(self): """tests that an eager relation can be upgraded to a lazy relation via the options method""" sess = create_session() mapper(User, users, properties = dict( @@ -667,7 +667,7 @@ class MapperTest(MapperSuperTest): self.assert_result(l, User, *user_address_result) self.assert_sql_count(testbase.db, go, 3) - def testlatecompile(self): + def test_latecompile(self): """tests mappers compiling late in the game""" mapper(User, users, properties = {'orders': relation(Order)}) @@ -681,7 +681,7 @@ class MapperTest(MapperSuperTest): print u[0].orders[1].items[0].keywords[1] self.assert_sql_count(testbase.db, go, 3) - def testdeepoptions(self): + def test_deepoptions(self): mapper(User, users, properties = { 'orders': relation(mapper(Order, orders, properties = { @@ -733,7 +733,7 @@ class MapperTest(MapperSuperTest): class DeferredTest(MapperSuperTest): - def testbasic(self): + def test_basic(self): """tests a basic "deferred" load""" m = mapper(Order, orders, properties={ @@ -755,7 +755,7 @@ class DeferredTest(MapperSuperTest): ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :param_1", {'param_1':3}) ]) - def testunsaved(self): + def test_unsaved(self): """test that deferred loading doesnt kick in when just PK cols are set""" m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description) @@ -769,7 +769,7 @@ class DeferredTest(MapperSuperTest): o.description = "some description" self.assert_sql_count(testbase.db, go, 0) - def testunsavedgroup(self): + def test_unsavedgroup(self): """test that deferred loading doesnt kick in when just PK cols are set""" m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description, group='primary'), @@ -784,7 +784,7 @@ class DeferredTest(MapperSuperTest): o.description = "some description" self.assert_sql_count(testbase.db, go, 0) - def testsave(self): + def test_save(self): m = mapper(Order, orders, properties={ 'description':deferred(orders.c.description) }) @@ -796,7 +796,7 @@ class DeferredTest(MapperSuperTest): o2.isopen = 1 sess.flush() - def testgroup(self): + def test_group(self): """tests deferred load with a group""" m = mapper(Order, orders, properties = { 'userident':deferred(orders.c.user_id, group='primary'), @@ -827,7 +827,7 @@ class DeferredTest(MapperSuperTest): sess.flush() self.assert_sql_count(testbase.db, go, 0) - def testcommitsstate(self): + def test_commitsstate(self): """test that when deferred elements are loaded via a group, they get the proper CommittedState and dont result in changes being committed""" @@ -849,7 +849,7 @@ class DeferredTest(MapperSuperTest): sess.flush() self.assert_sql_count(testbase.db, go, 0) - def testoptions(self): + def test_options(self): """tests using options on a mapper to create deferred and undeferred columns""" m = mapper(Order, orders) sess = create_session() @@ -873,7 +873,7 @@ class DeferredTest(MapperSuperTest): ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.description AS orders_description, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}), ]) - def testundefergroup(self): + def test_undefergroup(self): """tests undefer_group()""" m = mapper(Order, orders, properties = { 'userident':deferred(orders.c.user_id, group='primary'), @@ -895,7 +895,7 @@ class DeferredTest(MapperSuperTest): ]) - def testdeepoptions(self): + def test_deepoptions(self): m = mapper(User, users, properties={ 'orders':relation(mapper(Order, orders, properties={ 'items':relation(mapper(Item, orderitems, properties={ @@ -1051,7 +1051,7 @@ class CompositeTypesTest(ORMTest): class NoLoadTest(MapperSuperTest): - def testbasic(self): + def test_basic(self): """tests a basic one-to-many lazy load""" m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy=None) @@ -1067,7 +1067,7 @@ class NoLoadTest(MapperSuperTest): self.assert_result(l[0], User, {'user_id' : 7, 'addresses' : (Address, [])}, ) - def testoptions(self): + def test_options(self): m = mapper(User, users, properties = dict( addresses = relation(mapper(Address, addresses), lazy=None) )) @@ -1084,7 +1084,17 @@ class NoLoadTest(MapperSuperTest): ) class MapperExtensionTest(MapperSuperTest): - def testcreateinstance(self): + def setUpAll(self): + tables.create() + def tearDownAll(self): + tables.drop() + def tearDown(self): + clear_mappers() + tables.delete() + def setUp(self): + tables.data() + + def test_create_instance(self): class Ext(MapperExtension): def create_instance(self, *args, **kwargs): return User() @@ -1097,6 +1107,74 @@ class MapperExtensionTest(MapperSuperTest): l = q.select(); self.assert_result(l, User, *user_address_result) + def test_methods(self): + """test that common user-defined methods get called.""" + + methods = set() + class Ext(MapperExtension): + def load(self, query, *args, **kwargs): + methods.add('load') + return EXT_CONTINUE + + def get(self, query, *args, **kwargs): + methods.add('get') + return EXT_CONTINUE + + def translate_row(self, mapper, context, row): + methods.add('translate_row') + return EXT_CONTINUE + + def create_instance(self, mapper, selectcontext, row, class_): + methods.add('create_instance') + return EXT_CONTINUE + + def append_result(self, mapper, selectcontext, row, instance, result, **flags): + methods.add('append_result') + return EXT_CONTINUE + + def populate_instance(self, mapper, selectcontext, row, instance, **flags): + methods.add('populate_instance') + return EXT_CONTINUE + + def before_insert(self, mapper, connection, instance): + methods.add('before_insert') + return EXT_CONTINUE + + def after_insert(self, mapper, connection, instance): + methods.add('after_insert') + return EXT_CONTINUE + + def before_update(self, mapper, connection, instance): + methods.add('before_update') + return EXT_CONTINUE + + def after_update(self, mapper, connection, instance): + methods.add('after_update') + return EXT_CONTINUE + + def before_delete(self, mapper, connection, instance): + methods.add('before_delete') + return EXT_CONTINUE + + def after_delete(self, mapper, connection, instance): + methods.add('after_delete') + return EXT_CONTINUE + + mapper(User, users, extension=Ext()) + sess = create_session() + u = User() + sess.save(u) + sess.flush() + u = sess.query(User).load(u.user_id) + sess.clear() + u = sess.query(User).get(u.user_id) + u.user_name = 'foobar' + sess.flush() + sess.delete(u) + sess.flush() + assert methods == set(['load', 'append_result', 'before_delete', 'create_instance', 'translate_row', 'get', + 'after_delete', 'after_insert', 'before_update', 'before_insert', 'after_update', 'populate_instance']) + if __name__ == "__main__": testbase.main()