From: Mike Bayer Date: Fri, 3 Aug 2007 19:31:38 +0000 (+0000) Subject: - removed enhance_classes from scoped_session, replaced with X-Git-Tag: rel_0_4beta1~88 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e7c83bb37133af7b0deaef2fbc0d0fae8a179dfc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - removed enhance_classes from scoped_session, replaced with scoped_session(...).mapper. 'mapper' essentially does the same thing as assign_mapper less verbosely. - adapted assignmapper unit tests into scoped_session tests --- diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 96d9a23fc0..5d11a99a4a 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,4 +1,4 @@ -from sqlalchemy.util import ScopedRegistry, warn_deprecated +from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list from sqlalchemy.orm import MapperExtension, EXT_CONTINUE from sqlalchemy.orm.session import Session from sqlalchemy.orm.mapper import global_extensions @@ -13,16 +13,21 @@ class ScopedSession(object): Usage:: - Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True) + Session = scoped_session(sessionmaker(autoflush=True)) + + To map classes so that new instances are saved in the current + Session automatically, as well as to provide session-aware + class attributes such as "query": + + mapper = Session.mapper + mapper(Class, table, ...) """ - def __init__(self, session_factory, scopefunc=None, enhance_classes=False): + def __init__(self, session_factory, scopefunc=None): self.session_factory = session_factory - self.enhance_classes = enhance_classes self.registry = ScopedRegistry(session_factory, scopefunc) - if self.enhance_classes: - global_extensions.append(_ScopedExt(self)) + self.extension = _ScopedExt(self) def __call__(self, **kwargs): if kwargs: @@ -39,15 +44,28 @@ class ScopedSession(object): else: return self.registry() + def mapper(self, *args, **kwargs): + """return a mapper() function which associates this ScopedSession with the Mapper.""" + + from sqlalchemy.orm import mapper + validate = kwargs.pop('validate', False) + extension = to_list(kwargs.setdefault('extension', [])) + if validate: + extension.append(self.extension.validating()) + else: + extension.append(self.extension) + return mapper(*args, **kwargs) + def configure(self, **kwargs): - """reconfigure the sessionmaker used by this SessionContext""" + """reconfigure the sessionmaker used by this ScopedSession.""" + self.session_factory.configure(**kwargs) def instrument(name): def do(self, *args, **kwargs): return getattr(self.registry(), name)(*args, **kwargs) return do -for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'): +for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'): setattr(ScopedSession, meth, instrument(meth)) def makeprop(name): @@ -67,18 +85,22 @@ for prop in ('close_all',): setattr(ScopedSession, prop, clslevel(prop)) class _ScopedExt(MapperExtension): - def __init__(self, context): + def __init__(self, context, validate=False): self.context = context + self.validate = validate + + def validating(self): + return _ScopedExt(self.context, validate=True) def get_session(self): return self.context.registry() def instrument_class(self, mapper, class_): class query(object): - def __getattr__(self, key): - return getattr(registry().query(class_), key) - def __call__(self): - return registry().query(class_) + def __getattr__(s, key): + return getattr(self.context.registry().query(class_), key) + def __call__(s): + return self.context.registry().query(class_) if not hasattr(class_, 'query'): class_.query = query() @@ -87,9 +109,9 @@ class _ScopedExt(MapperExtension): session = kwargs.pop('_sa_session', self.context.registry()) if not isinstance(oldinit, types.MethodType): for key, value in kwargs.items(): - #if validate: - # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False): - # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + if self.validate: + if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False): + raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) setattr(instance, key, value) session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None)) return EXT_CONTINUE diff --git a/test/orm/session.py b/test/orm/session.py index d3eed5c570..0b56b84d4f 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -4,7 +4,6 @@ from sqlalchemy.orm import * from testlib import * from testlib.tables import * import testlib.tables as tables -from sqlalchemy.orm.session import Session class SessionTest(AssertMixin): def setUpAll(self): @@ -98,7 +97,7 @@ class SessionTest(AssertMixin): conn1 = testbase.db.connect() conn2 = testbase.db.connect() - sess = Session(bind=conn1, transactional=True, autoflush=True) + sess = create_session(bind=conn1, transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) @@ -116,7 +115,7 @@ class SessionTest(AssertMixin): mapper(User, users) try: - sess = Session(transactional=True, autoflush=True) + sess = create_session(transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) @@ -137,7 +136,7 @@ class SessionTest(AssertMixin): conn1 = testbase.db.connect() conn2 = testbase.db.connect() - sess = Session(bind=conn1, transactional=True, autoflush=True) + sess = create_session(bind=conn1, transactional=True, autoflush=True) u = User() u.user_name='ed' sess.save(u) @@ -153,7 +152,7 @@ class SessionTest(AssertMixin): 'addresses':relation(Address) }) - sess = Session(transactional=True, autoflush=True) + sess = create_session(transactional=True, autoflush=True) u = sess.query(User).get(8) newad = Address() newad.email_address == 'something new' @@ -173,7 +172,7 @@ class SessionTest(AssertMixin): mapper(User, users) conn = testbase.db.connect() trans = conn.begin() - sess = Session(conn, transactional=True, autoflush=True) + sess = create_session(bind=conn, transactional=True, autoflush=True) sess.begin() u = User() sess.save(u) @@ -189,7 +188,7 @@ class SessionTest(AssertMixin): try: conn = testbase.db.connect() trans = conn.begin() - sess = Session(conn, transactional=True, autoflush=True) + sess = create_session(bind=conn, transactional=True, autoflush=True) u1 = User() sess.save(u1) sess.flush() @@ -217,7 +216,7 @@ class SessionTest(AssertMixin): mapper(Address, addresses) engine2 = create_engine(testbase.db.url) - sess = Session(transactional=False, autoflush=False, twophase=True) + sess = create_session(transactional=False, autoflush=False, twophase=True) sess.bind_mapper(User, testbase.db) sess.bind_mapper(Address, engine2) sess.begin() @@ -234,7 +233,7 @@ class SessionTest(AssertMixin): def test_joined_transaction(self): class User(object):pass mapper(User, users) - sess = Session(transactional=True, autoflush=True) + sess = create_session(transactional=True, autoflush=True) sess.begin() u = User() sess.save(u) @@ -440,6 +439,75 @@ class SessionTest(AssertMixin): key = s.identity_key(User, row=row, entity_name="en") self._assert_key(key, (User, (1,), "en")) +class ScopedSessionTest(PersistTest): + def setUpAll(self): + global metadata, table, table2 + metadata = MetaData(testbase.db) + table = Table('sometable', metadata, + Column('id', Integer, primary_key=True), + Column('data', String(30))) + table2 = Table('someothertable', metadata, + Column('id', Integer, primary_key=True), + Column('someid', None, ForeignKey('sometable.id')) + ) + metadata.create_all() + + def setUp(self): + global SomeObject, SomeOtherObject + class SomeObject(object):pass + class SomeOtherObject(object):pass + global Session + + Session = scoped_session(create_session) + Session.mapper(SomeObject, table, properties={ + 'options':relation(SomeOtherObject) + }) + Session.mapper(SomeOtherObject, table2) + + s = SomeObject() + s.id = 1 + s.data = 'hello' + sso = SomeOtherObject() + s.options.append(sso) + Session.flush() + Session.clear() + + def tearDownAll(self): + metadata.drop_all() + + def tearDown(self): + for table in metadata.table_iterator(reverse=True): + table.delete().execute() + clear_mappers() + + def test_query(self): + sso = SomeOtherObject.query().first() + assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id + + def test_validating_constructor(self): + s2 = SomeObject(someid=12) + s3 = SomeOtherObject(someid=123, bogus=345) + + class ValidatedOtherObject(object):pass + Session.mapper(ValidatedOtherObject, table2, validate=True) + + v1 = ValidatedOtherObject(someid=12) + try: + v2 = ValidatedOtherObject(someid=12, bogus=345) + assert False + except exceptions.ArgumentError: + pass + + def test_dont_clobber_methods(self): + class MyClass(object): + def expunge(self): + return "an expunge !" + + Session.mapper(MyClass, table2) + + assert MyClass().expunge() == "an expunge !" + + if __name__ == "__main__": testbase.main() diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index f28e428ed9..f065267f7d 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -12,8 +12,9 @@ from testlib import tables class UnitOfWorkTest(AssertMixin): def setUpAll(self): - global Session - Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True) + global Session, mapper + Session = scoped_session(sessionmaker(autoflush=True, transactional=True)) + mapper = Session.mapper def tearDownAll(self): global_extensions[:] = [] def tearDown(self):