-from sqlalchemy.util import ScopedRegistry, warn_deprecated
+from sqlalchemy.util import ScopedRegistry, warn_deprecated, to_list
from sqlalchemy.orm import MapperExtension, EXT_CONTINUE
from sqlalchemy.orm.session import Session
from sqlalchemy.orm.mapper import global_extensions
Usage::
- Session = scoped_session(sessionmaker(autoflush=True), enhance_classes=True)
+ Session = scoped_session(sessionmaker(autoflush=True))
+
+ To map classes so that new instances are saved in the current
+ Session automatically, as well as to provide session-aware
+ class attributes such as "query":
+
+ mapper = Session.mapper
+ mapper(Class, table, ...)
"""
- def __init__(self, session_factory, scopefunc=None, enhance_classes=False):
+ def __init__(self, session_factory, scopefunc=None):
self.session_factory = session_factory
- self.enhance_classes = enhance_classes
self.registry = ScopedRegistry(session_factory, scopefunc)
- if self.enhance_classes:
- global_extensions.append(_ScopedExt(self))
+ self.extension = _ScopedExt(self)
def __call__(self, **kwargs):
if kwargs:
else:
return self.registry()
+ def mapper(self, *args, **kwargs):
+ """return a mapper() function which associates this ScopedSession with the Mapper."""
+
+ from sqlalchemy.orm import mapper
+ validate = kwargs.pop('validate', False)
+ extension = to_list(kwargs.setdefault('extension', []))
+ if validate:
+ extension.append(self.extension.validating())
+ else:
+ extension.append(self.extension)
+ return mapper(*args, **kwargs)
+
def configure(self, **kwargs):
- """reconfigure the sessionmaker used by this SessionContext"""
+ """reconfigure the sessionmaker used by this ScopedSession."""
+
self.session_factory.configure(**kwargs)
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
-for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete', 'clear'):
setattr(ScopedSession, meth, instrument(meth))
def makeprop(name):
setattr(ScopedSession, prop, clslevel(prop))
class _ScopedExt(MapperExtension):
- def __init__(self, context):
+ def __init__(self, context, validate=False):
self.context = context
+ self.validate = validate
+
+ def validating(self):
+ return _ScopedExt(self.context, validate=True)
def get_session(self):
return self.context.registry()
def instrument_class(self, mapper, class_):
class query(object):
- def __getattr__(self, key):
- return getattr(registry().query(class_), key)
- def __call__(self):
- return registry().query(class_)
+ def __getattr__(s, key):
+ return getattr(self.context.registry().query(class_), key)
+ def __call__(s):
+ return self.context.registry().query(class_)
if not hasattr(class_, 'query'):
class_.query = query()
session = kwargs.pop('_sa_session', self.context.registry())
if not isinstance(oldinit, types.MethodType):
for key, value in kwargs.items():
- #if validate:
- # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
- # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ if self.validate:
+ if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
setattr(instance, key, value)
session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
return EXT_CONTINUE
from testlib import *
from testlib.tables import *
import testlib.tables as tables
-from sqlalchemy.orm.session import Session
class SessionTest(AssertMixin):
def setUpAll(self):
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
- sess = Session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
mapper(User, users)
try:
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
conn1 = testbase.db.connect()
conn2 = testbase.db.connect()
- sess = Session(bind=conn1, transactional=True, autoflush=True)
+ sess = create_session(bind=conn1, transactional=True, autoflush=True)
u = User()
u.user_name='ed'
sess.save(u)
'addresses':relation(Address)
})
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
u = sess.query(User).get(8)
newad = Address()
newad.email_address == 'something new'
mapper(User, users)
conn = testbase.db.connect()
trans = conn.begin()
- sess = Session(conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
try:
conn = testbase.db.connect()
trans = conn.begin()
- sess = Session(conn, transactional=True, autoflush=True)
+ sess = create_session(bind=conn, transactional=True, autoflush=True)
u1 = User()
sess.save(u1)
sess.flush()
mapper(Address, addresses)
engine2 = create_engine(testbase.db.url)
- sess = Session(transactional=False, autoflush=False, twophase=True)
+ sess = create_session(transactional=False, autoflush=False, twophase=True)
sess.bind_mapper(User, testbase.db)
sess.bind_mapper(Address, engine2)
sess.begin()
def test_joined_transaction(self):
class User(object):pass
mapper(User, users)
- sess = Session(transactional=True, autoflush=True)
+ sess = create_session(transactional=True, autoflush=True)
sess.begin()
u = User()
sess.save(u)
key = s.identity_key(User, row=row, entity_name="en")
self._assert_key(key, (User, (1,), "en"))
+class ScopedSessionTest(PersistTest):
+ def setUpAll(self):
+ global metadata, table, table2
+ metadata = MetaData(testbase.db)
+ table = Table('sometable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(30)))
+ table2 = Table('someothertable', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('someid', None, ForeignKey('sometable.id'))
+ )
+ metadata.create_all()
+
+ def setUp(self):
+ global SomeObject, SomeOtherObject
+ class SomeObject(object):pass
+ class SomeOtherObject(object):pass
+ global Session
+
+ Session = scoped_session(create_session)
+ Session.mapper(SomeObject, table, properties={
+ 'options':relation(SomeOtherObject)
+ })
+ Session.mapper(SomeOtherObject, table2)
+
+ s = SomeObject()
+ s.id = 1
+ s.data = 'hello'
+ sso = SomeOtherObject()
+ s.options.append(sso)
+ Session.flush()
+ Session.clear()
+
+ def tearDownAll(self):
+ metadata.drop_all()
+
+ def tearDown(self):
+ for table in metadata.table_iterator(reverse=True):
+ table.delete().execute()
+ clear_mappers()
+
+ def test_query(self):
+ sso = SomeOtherObject.query().first()
+ assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+ def test_validating_constructor(self):
+ s2 = SomeObject(someid=12)
+ s3 = SomeOtherObject(someid=123, bogus=345)
+
+ class ValidatedOtherObject(object):pass
+ Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+ v1 = ValidatedOtherObject(someid=12)
+ try:
+ v2 = ValidatedOtherObject(someid=12, bogus=345)
+ assert False
+ except exceptions.ArgumentError:
+ pass
+
+ def test_dont_clobber_methods(self):
+ class MyClass(object):
+ def expunge(self):
+ return "an expunge !"
+
+ Session.mapper(MyClass, table2)
+
+ assert MyClass().expunge() == "an expunge !"
+
+
if __name__ == "__main__":
testbase.main()
class UnitOfWorkTest(AssertMixin):
def setUpAll(self):
- global Session
- Session = scoped_session(sessionmaker(autoflush=True, transactional=True), enhance_classes=True)
+ global Session, mapper
+ Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+ mapper = Session.mapper
def tearDownAll(self):
global_extensions[:] = []
def tearDown(self):