From: Mike Bayer Date: Fri, 20 Jul 2007 16:52:00 +0000 (+0000) Subject: changed assignmapper API per [ticket:636] X-Git-Tag: rel_0_4_6~63 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a4ec7c80ee746921da8a8da54773a5596a780d96;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git changed assignmapper API per [ticket:636] --- diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py index f78e464933..03c6cc1f83 100644 --- a/lib/sqlalchemy/ext/assignmapper.py +++ b/lib/sqlalchemy/ext/assignmapper.py @@ -1,18 +1,8 @@ from sqlalchemy import util, exceptions import types -from sqlalchemy.orm import mapper, Query - -def monkeypatch_query_method(ctx, class_, name): - def do(self, *args, **kwargs): - query = Query(class_, session=ctx.current) - return getattr(query, name)(*args, **kwargs) - try: - do.__name__ = name - except: - pass - setattr(class_, name, classmethod(do)) - -def monkeypatch_objectstore_method(ctx, class_, name): +from sqlalchemy.orm import mapper + +def _monkeypatch_session_method(name, ctx, class_): def do(self, *args, **kwargs): session = ctx.current return getattr(session, name)(self, *args, **kwargs) @@ -21,31 +11,37 @@ def monkeypatch_objectstore_method(ctx, class_, name): except: pass setattr(class_, name, do) - - + def assign_mapper(ctx, class_, *args, **kwargs): - validate = kwargs.pop('validate', False) - if not isinstance(getattr(class_, '__init__'), types.MethodType): - def __init__(self, **kwargs): - if validate: - keys = [p.key for p in self.mapper.iterate_properties] - for key, value in kwargs.items(): - if validate and key not in keys: - raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) - setattr(self, key, value) - class_.__init__ = __init__ extension = kwargs.pop('extension', None) if extension is not None: extension = util.to_list(extension) extension.append(ctx.mapper_extension) else: extension = ctx.mapper_extension + + validate = kwargs.pop('validate', False) + + if not isinstance(getattr(class_, '__init__'), types.MethodType): + def __init__(self, **kwargs): + for key, value in kwargs.items(): + if validate: + if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False): + raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + setattr(self, key, value) + class_.__init__ = __init__ + + class query(object): + def __getattr__(self, key): + return getattr(ctx.current.query(class_), key) + def __call__(self): + return ctx.current.query(class_) + class_.query = query() + + for name in ['refresh', 'expire', 'delete', 'expunge', 'update']: + _monkeypatch_session_method(name, ctx, class_) + m = mapper(class_, extension=extension, *args, **kwargs) class_.mapper = m - class_.query = classmethod(lambda cls: Query(class_, session=ctx.current)) - for name in ['get', 'filter', 'filter_by', 'select', 'select_by', 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', 'get_by', 'join', 'count', 'count_by', 'options', 'instances']: - monkeypatch_query_method(ctx, class_, name) - for name in ['delete', 'expire', 'refresh', 'expunge', 'save', 'update', 'save_or_update']: - monkeypatch_objectstore_method(ctx, class_, name) return m diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py index 8562a016a2..dbad8de9d1 100644 --- a/test/ext/assignmapper.py +++ b/test/ext/assignmapper.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.assignmapper import assign_mapper from sqlalchemy.ext.sessioncontext import SessionContext from testbase import Table, Column -class OverrideAttributesTest(PersistTest): +class AssignMapperTest(PersistTest): def setUpAll(self): global metadata, table, table2 metadata = MetaData(testbase.db) @@ -20,13 +20,9 @@ class OverrideAttributesTest(PersistTest): Column('someid', None, ForeignKey('sometable.id')) ) metadata.create_all() - def tearDownAll(self): - metadata.drop_all() - def tearDown(self): - clear_mappers() + def setUp(self): - pass - def test_override_attributes(self): + global SomeObject, SomeOtherObject, ctx class SomeObject(object):pass class SomeOtherObject(object):pass @@ -35,7 +31,7 @@ class OverrideAttributesTest(PersistTest): 'options':relation(SomeOtherObject) }) assign_mapper(ctx, SomeOtherObject, table2) - class_mapper(SomeObject) + s = SomeObject() s.id = 1 s.data = 'hello' @@ -43,8 +39,17 @@ class OverrideAttributesTest(PersistTest): s.options.append(sso) ctx.current.flush() ctx.current.clear() + + def tearDownAll(self): + metadata.drop_all() + def tearDown(self): + clear_mappers() + + def test_override_attributes(self): - assert SomeObject.get_by(id=s.id).options[0].id == sso.id + sso = SomeOtherObject.query().first() + + assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id s2 = SomeObject(someid=12) s3 = SomeOtherObject(someid=123, bogus=345) @@ -58,6 +63,8 @@ class OverrideAttributesTest(PersistTest): assert False except exceptions.ArgumentError: pass + + if __name__ == '__main__': testbase.main()