# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-import weakref
+import weakref, types
from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query, util as mapperutil
+from sqlalchemy.orm import unitofwork, query, util as mapperutil, MapperExtension, EXT_CONTINUE
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
+from sqlalchemy.orm.mapper import global_extensions
+__all__ = ['Session', 'SessionTransaction']
+
+def sessionmaker(autoflush, transactional, bind=None, scope=None, enhance_classes=False, **kwargs):
+ """Generate a Session configuration."""
+
+ if enhance_classes and scope is None:
+ raise exceptions.InvalidRequestError("enhance_classes requires a non-None 'scope' argument, so that mappers can automatically locate a Session already in progress.")
+
+ class Sess(Session):
+ def __init__(self, **local_kwargs):
+ local_kwargs.setdefault('bind', bind)
+ local_kwargs.setdefault('autoflush', autoflush)
+ local_kwargs.setdefault('transactional', transactional)
+ for k in kwargs:
+ local_kwargs.setdefault(k, kwargs[k])
+ super(Sess, self).__init__(**local_kwargs)
+
+ if scope=="thread":
+ registry = util.ScopedRegistry(Sess, scopefunc=None)
+
+ if enhance_classes:
+ class SessionContextExt(MapperExtension):
+ def get_session(self):
+ return registry()
+
+ def instrument_class(self, mapper, class_):
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(registry().query(class_), key)
+ def __call__(self):
+ return registry().query(class_)
+
+ if not hasattr(class_, 'query'):
+ class_.query = query()
+
+ def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', registry())
+ if not isinstance(oldinit, types.MethodType):
+ for key, value in kwargs.items():
+ #if validate:
+ # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ setattr(instance, key, value)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_CONTINUE
+
+ def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_CONTINUE
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+ if hasattr(class_, 'query'):
+ delattr(class_, 'query')
+
+ global_extensions.append(SessionContextExt())
+
+ default_scope=scope
+ class ScopedSess(Sess):
+ def __new__(cls, **kwargs):
+ if len(kwargs):
+ scope = kwargs.pop('scope', default_scope)
+ if scope is not None:
+ if registry.has():
+ raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+ else:
+ sess = Sess(**kwargs)
+ registry.set(sess)
+ return sess
+ else:
+ return Sess(**kwargs)
+ else:
+ return registry()
+ def instrument(name):
+ def do(cls, *args, **kwargs):
+ return getattr(registry(), name)(*args, **kwargs)
+ return classmethod(do)
+ for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+ setattr(ScopedSess, meth, instrument(meth))
+
+ return ScopedSess
+ elif scope is not None:
+ raise exceptions.ArgumentError("Unknown scope '%s'" % scope)
+ else:
+ return session
+
class SessionTransaction(object):
"""Represents a Session-level Transaction.
--- /dev/null
+import types
+
+from sqlalchemy import util, exceptions
+from sqlalchemy.orm.session import Session
+from sqlalchemy.orm import query, util as mapperutil, MapperExtension, EXT_CONTINUE
+from sqlalchemy.orm.mapper import global_extensions
+
+def sessionmaker(autoflush, transactional, bind=None, scope=None, enhance_classes=False, **kwargs):
+ """Generate a Session configuration."""
+
+ if enhance_classes and scope is None:
+ raise exceptions.InvalidRequestError("enhance_classes requires a non-None 'scope' argument, so that mappers can automatically locate a Session already in progress.")
+
+ class Sess(Session):
+ def __init__(self, **local_kwargs):
+ local_kwargs.setdefault('bind', bind)
+ local_kwargs.setdefault('autoflush', autoflush)
+ local_kwargs.setdefault('transactional', transactional)
+ for k in kwargs:
+ local_kwargs.setdefault(k, kwargs[k])
+ super(Sess, self).__init__(**local_kwargs)
+
+ if scope=="thread":
+ registry = util.ScopedRegistry(Sess, scopefunc=None)
+
+ if enhance_classes:
+ class SessionContextExt(MapperExtension):
+ def get_session(self):
+ return registry()
+
+ def instrument_class(self, mapper, class_):
+ class query(object):
+ def __getattr__(self, key):
+ return getattr(registry().query(class_), key)
+ def __call__(self):
+ return registry().query(class_)
+
+ if not hasattr(class_, 'query'):
+ class_.query = query()
+
+ def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+ session = kwargs.pop('_sa_session', registry())
+ if not isinstance(oldinit, types.MethodType):
+ for key, value in kwargs.items():
+ #if validate:
+ # if not self.mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
+ # raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+ setattr(instance, key, value)
+ session._save_impl(instance, entity_name=kwargs.pop('_sa_entity_name', None))
+ return EXT_CONTINUE
+
+ def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+ object_session(instance).expunge(instance)
+ return EXT_CONTINUE
+
+ def dispose_class(self, mapper, class_):
+ if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
+ if class_.__init__._oldinit is not None:
+ class_.__init__ = class_.__init__._oldinit
+ else:
+ delattr(class_, '__init__')
+ if hasattr(class_, 'query'):
+ delattr(class_, 'query')
+
+ global_extensions.append(SessionContextExt())
+
+ default_scope=scope
+ class ScopedSess(Sess):
+ def __new__(cls, **kwargs):
+ if len(kwargs):
+ scope = kwargs.pop('scope', default_scope)
+ if scope is not None:
+ if registry.has():
+ raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+ else:
+ sess = Sess(**kwargs)
+ registry.set(sess)
+ return sess
+ else:
+ return Sess(**kwargs)
+ else:
+ return registry()
+ def instrument(name):
+ def do(cls, *args, **kwargs):
+ return getattr(registry(), name)(*args, **kwargs)
+ return classmethod(do)
+ for meth in ('get', 'close', 'save', 'commit', 'update', 'flush', 'query', 'delete'):
+ setattr(ScopedSess, meth, instrument(meth))
+
+ return ScopedSess
+ elif scope is not None:
+ raise exceptions.ArgumentError("Unknown scope '%s'" % scope)
+ else:
+ return Sess
from sqlalchemy.orm import *
from sqlalchemy.orm.mapper import global_extensions
from sqlalchemy.orm import util as ormutil
-from sqlalchemy.ext.sessioncontext import SessionContext
-import sqlalchemy.ext.assignmapper as assignmapper
from testlib import *
from testlib.tables import *
from testlib import tables
class UnitOfWorkTest(AssertMixin):
def setUpAll(self):
- global ctx, assign_mapper
- ctx = SessionContext(Session)
- def assign_mapper(*args, **kwargs):
- return assignmapper.assign_mapper(ctx, *args, **kwargs)
- global_extensions.append(ctx.mapper_extension)
+ global Session
+ Session = sessionmaker(autoflush=True, transactional=True, enhance_classes=True, scope="thread")
def tearDownAll(self):
- global_extensions.remove(ctx.mapper_extension)
+ global_extensions[:] = []
def tearDown(self):
Session.close_all()
- ctx.current.close()
clear_mappers()
class HistoryTest(UnitOfWorkTest):
class VersioningTest(UnitOfWorkTest):
def setUpAll(self):
UnitOfWorkTest.setUpAll(self)
- ctx.current.close()
+ Session.close()
global version_table
version_table = Table('version_test', MetaData(testbase.db),
Column('id', Integer, Sequence('version_test_seq'), primary_key=True ),
version_table.delete().execute()
def testbasic(self):
- s = Session()
+ s = Session(scope=None)
class Foo(object):pass
- assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+ mapper(Foo, version_table, version_id_col=version_table.c.version_id)
f1 =Foo(value='f1', _sa_session=s)
f2 = Foo(value='f2', _sa_session=s)
s.commit()
def testversioncheck(self):
"""test that query.with_lockmode performs a 'version check' on an already loaded instance"""
- s1 = Session()
+ s1 = Session(scope=None)
class Foo(object):pass
- assign_mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+ mapper(Foo, version_table, version_id_col=version_table.c.version_id)
f1s1 =Foo(value='f1', _sa_session=s1)
s1.commit()
s2 = Session()
"""test that query.with_lockmode works OK when the mapper has no version id col"""
s1 = Session()
class Foo(object):pass
- assign_mapper(Foo, version_table)
+ mapper(Foo, version_table)
f1s1 =Foo(value='f1', _sa_session=s1)
f1s1.version_id=0
s1.commit()
txt = u"\u0160\u0110\u0106\u010c\u017d"
t1 = Test(id=1, txt = txt)
self.assert_(t1.txt == txt)
- ctx.current.commit()
+ Session.commit()
self.assert_(t1.txt == txt)
def testrelation(self):
class Test(object):
t1 = Test(txt=txt)
t1.t2s.append(Test2())
t1.t2s.append(Test2())
- ctx.current.commit()
- ctx.current.close()
- t1 = ctx.current.query(Test).get_by(id=t1.id)
+ Session.commit()
+ Session.close()
+ t1 = Session.query(Test).get_by(id=t1.id)
assert len(t1.t2s) == 2
class MutableTypesTest(UnitOfWorkTest):
mapper(Foo, table)
f1 = Foo()
f1.data = pickleable.Bar(4,5)
- ctx.current.commit()
- ctx.current.close()
- f2 = ctx.current.query(Foo).get_by(id=f1.id)
+ Session.commit()
+ Session.close()
+ f2 = Session.query(Foo).get_by(id=f1.id)
assert f2.data == f1.data
f2.data.y = 19
- ctx.current.commit()
- ctx.current.close()
- f3 = ctx.current.query(Foo).get_by(id=f1.id)
+ Session.commit()
+ Session.close()
+ f3 = Session.query(Foo).get_by(id=f1.id)
print f2.data, f3.data
assert f3.data != f1.data
assert f3.data == pickleable.Bar(4, 19)
f1 = Foo()
f1.data = pickleable.Bar(4,5)
f1.value = unicode('hi')
- ctx.current.commit()
+ Session.commit()
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
f1.value = unicode('someothervalue')
- self.assert_sql(testbase.db, lambda: ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda: Session.commit(), [
(
"UPDATE mutabletest SET value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'someothervalue'}
])
f1.value = unicode('hi')
f1.data.x = 9
- self.assert_sql(testbase.db, lambda: ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda: Session.commit(), [
(
"UPDATE mutabletest SET data=:data, value=:value WHERE mutabletest.id = :mutabletest_id",
{'mutabletest_id': f1.id, 'value': u'hi', 'data':f1.data}
mapper(Foo, table)
f1 = Foo()
f1.data = pickleable.BarWithoutCompare(4,5)
- ctx.current.commit()
+ Session.commit()
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
- ctx.current.close()
+ Session.close()
- f2 = ctx.current.query(Foo).get_by(id=f1.id)
+ f2 = Session.query(Foo).get_by(id=f1.id)
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
f2.data.y = 19
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 1)
- ctx.current.close()
- f3 = ctx.current.query(Foo).get_by(id=f1.id)
+ Session.close()
+ f3 = Session.query(Foo).get_by(id=f1.id)
print f2.data, f3.data
assert (f3.data.x, f3.data.y) == (4,19)
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
def testunicode(self):
mapper(Foo, table)
f1 = Foo()
f1.value = u'hi'
- ctx.current.commit()
- ctx.current.close()
- f1 = ctx.current.get(Foo, f1.id)
+ Session.commit()
+ Session.close()
+ f1 = Session.get(Foo, f1.id)
f1.value = u'hi'
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
e.name = 'entry1'
e.value = 'this is entry 1'
e.multi_rev = 2
- ctx.current.commit()
- ctx.current.close()
+ Session.commit()
+ Session.close()
e2 = Query(Entry).get((e.multi_id, 2))
self.assert_(e is not e2 and e._instance_key == e2._instance_key)
e.pk_col_1 = 'pk1'
e.pk_col_2 = 'pk1_related'
e.data = 'im the data'
- ctx.current.commit()
+ Session.commit()
def testkeypks(self):
import datetime
e.secondary = 'pk2'
e.assigned = datetime.date.today()
e.data = 'some more data'
- ctx.current.commit()
+ Session.commit()
def testpksimmutable(self):
class Entry(object):
e.multi_id=5
e.multi_rev=5
e.name='somename'
- ctx.current.commit()
+ Session.commit()
e.multi_rev=6
e.name = 'someothername'
try:
- ctx.current.commit()
+ Session.commit()
assert False
except exceptions.FlushError, fe:
assert str(fe) == "Can't change the identity of instance Entry@%s in session (existing identity: (%s, (5, 5), None); new identity: (%s, (5, 6), None))" % (hex(id(e)), repr(e.__class__), repr(e.__class__))
ps = PersonSite()
ps.site = 'asdf'
p.sites.append(ps)
- ctx.current.commit()
+ Session.commit()
assert people.count(people.c.person=='im the key').scalar() == peoplesites.count(peoplesites.c.person=='im the key').scalar() == 1
class PassiveDeletesTest(UnitOfWorkTest):
'children':relation(MyOtherClass, passive_deletes=True, cascade="all")
})
- sess = ctx.current
+ sess = Session
mc = MyClass()
mc.children.append(MyOtherClass())
mc.children.append(MyOtherClass())
def testinsert(self):
class Hoho(object):pass
- assign_mapper(Hoho, default_table)
+ mapper(Hoho, default_table)
h1 = Hoho(hoho=self.althohoval)
h2 = Hoho(counter=12)
h3 = Hoho(hoho=self.althohoval, counter=12)
h4 = Hoho()
h5 = Hoho(foober='im the new foober')
- ctx.current.commit()
+ Session.commit()
self.assert_(h1.hoho==self.althohoval)
self.assert_(h3.hoho==self.althohoval)
self.assert_(h5.foober=='im the new foober')
self.assert_sql_count(testbase.db, go, 0)
- ctx.current.close()
+ Session.close()
l = Query(Hoho).select()
def testinsertnopostfetch(self):
# populates the PassiveDefaults explicitly so there is no "post-update"
class Hoho(object):pass
- assign_mapper(Hoho, default_table)
+ mapper(Hoho, default_table)
h1 = Hoho(hoho="15", counter="15")
- ctx.current.commit()
+ Session.commit()
def go():
self.assert_(h1.hoho=="15")
self.assert_(h1.counter=="15")
def testupdate(self):
class Hoho(object):pass
- assign_mapper(Hoho, default_table)
+ mapper(Hoho, default_table)
h1 = Hoho()
- ctx.current.commit()
+ Session.commit()
self.assert_(h1.foober == 'im foober')
h1.counter = 19
- ctx.current.commit()
+ Session.commit()
self.assert_(h1.foober == 'im the update')
class OneToManyTest(UnitOfWorkTest):
a2.email_address = 'lala@test.org'
u.addresses.append(a2)
print repr(u.addresses)
- ctx.current.commit()
+ Session.commit()
usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()
self.assertEqual(usertable[0].values(), [u.user_id, 'one2manytester'])
a2.email_address = 'somethingnew@foo.com'
- ctx.current.commit()
+ Session.commit()
addresstable = addresses.select(addresses.c.address_id == addressid).execute().fetchall()
self.assertEqual(addresstable[0].values(), [addressid, userid, 'somethingnew@foo.com'])
a3 = Address()
a3.email_address = 'emailaddress3'
- ctx.current.commit()
+ Session.commit()
# modify user2 directly, append an address to user1.
# upon commit, user2 should be updated, user1 should not
u2.user_name = 'user2modified'
u1.addresses.append(a3)
del u1.addresses[0]
- self.assert_sql(testbase.db, lambda: ctx.current.commit(),
+ self.assert_sql(testbase.db, lambda: Session.commit(),
[
(
"UPDATE users SET user_name=:user_name WHERE users.user_id = :users_user_id",
a = Address()
a.email_address = 'address1'
u1.addresses.append(a)
- ctx.current.commit()
+ Session.commit()
del u1.addresses[0]
u2.addresses.append(a)
- ctx.current.delete(u1)
- ctx.current.commit()
- ctx.current.close()
- u2 = ctx.current.get(User, u2.user_id)
+ Session.delete(u1)
+ Session.commit()
+ Session.close()
+ u2 = Session.get(User, u2.user_id)
assert len(u2.addresses) == 1
def testchildmove_2(self):
a = Address()
a.email_address = 'address1'
u1.addresses.append(a)
- ctx.current.commit()
+ Session.commit()
del u1.addresses[0]
u2.addresses.append(a)
- ctx.current.commit()
- ctx.current.close()
- u2 = ctx.current.get(User, u2.user_id)
+ Session.commit()
+ Session.close()
+ u2 = Session.get(User, u2.user_id)
assert len(u2.addresses) == 1
def testo2mdeleteparent(self):
u.user_name = 'one2onetester'
u.address = a
u.address.email_address = 'myonlyaddress@foo.com'
- ctx.current.commit()
- ctx.current.delete(u)
- ctx.current.commit()
- self.assert_(a.address_id is not None and a.user_id is None and not ctx.current.identity_map.has_key(u._instance_key) and ctx.current.identity_map.has_key(a._instance_key))
+ Session.commit()
+ Session.delete(u)
+ Session.commit()
+ self.assert_(a.address_id is not None and a.user_id is None and not Session().identity_map.has_key(u._instance_key) and Session().identity_map.has_key(a._instance_key))
def testonetoone(self):
m = mapper(User, users, properties = dict(
u.user_name = 'one2onetester'
u.address = Address()
u.address.email_address = 'myonlyaddress@foo.com'
- ctx.current.commit()
+ Session.commit()
u.user_name = 'imnew'
- ctx.current.commit()
+ Session.commit()
u.address.email_address = 'imnew@foo.com'
- ctx.current.commit()
+ Session.commit()
def testbidirectional(self):
m1 = mapper(User, users)
a = Address()
a.email_address = 'testaddress'
a.user = u
- ctx.current.commit()
+ Session.commit()
print repr(u.addresses)
x = False
try:
if x:
self.assert_(False, "User addresses element should be scalar based")
- ctx.current.delete(u)
- ctx.current.commit()
+ Session.delete(u)
+ Session.commit()
def testdoublerelation(self):
m2 = mapper(Address, addresses)
u.boston_addresses.append(a)
u.newyork_addresses.append(b)
- ctx.current.commit()
+ Session.commit()
class SaveTest(UnitOfWorkTest):
u2 = User()
u2.user_name = 'savetester2'
- ctx.current.save(u)
+ Session.save(u)
- ctx.current.flush([u])
- ctx.current.commit()
+ Session.flush([u])
+ Session.commit()
# assert the first one retreives the same from the identity map
- nu = ctx.current.get(m, u.user_id)
+ nu = Session.get(m, u.user_id)
print "U: " + repr(u) + "NU: " + repr(nu)
self.assert_(u is nu)
# clear out the identity map, so next get forces a SELECT
- ctx.current.close()
+ Session.close()
# check it again, identity should be different but ids the same
- nu = ctx.current.get(m, u.user_id)
+ nu = Session.get(m, u.user_id)
self.assert_(u is not nu and u.user_id == nu.user_id and nu.user_name == 'savetester')
# change first users name and save
- ctx.current.update(u)
+ Session.update(u)
u.user_name = 'modifiedname'
- assert u in ctx.current.dirty
- ctx.current.commit()
+ assert u in Session().dirty
+ Session.commit()
# select both
- #ctx.current.close()
+ #Session.close()
userlist = Query(m).select(users.c.user_id.in_(u.user_id, u2.user_id), order_by=[users.c.user_name])
print repr(u.user_id), repr(userlist[0].user_id), repr(userlist[0].user_name)
self.assert_(u.user_id == userlist[0].user_id and userlist[0].user_name == 'modifiedname')
u.addresses.append(Address())
u.addresses.append(Address())
u.addresses.append(Address())
- ctx.current.commit()
- ctx.current.close()
- ulist = ctx.current.query(m1).select()
+ Session.commit()
+ Session.close()
+ ulist = Session.query(m1).select()
u1 = ulist[0]
u1.user_name = 'newname'
- ctx.current.commit()
+ Session.commit()
self.assert_(len(u1.addresses) == 4)
def testinherits(self):
)
au = AddressUser()
- ctx.current.commit()
- ctx.current.close()
- l = ctx.current.query(AddressUser).selectone()
+ Session.commit()
+ Session.close()
+ l = Session.query(AddressUser).selectone()
self.assert_(l.user_id == au.user_id and l.address_id == au.address_id)
def testdeferred(self):
})
u = User()
u.user_id=42
- ctx.current.commit()
+ Session.commit()
def test_dont_update_blanks(self):
mapper(User, users)
u = User()
u.user_name = ""
- ctx.current.commit()
- ctx.current.close()
- u = ctx.current.query(User).get(u.user_id)
+ Session.commit()
+ Session.close()
+ u = Session.query(User).get(u.user_id)
u.user_name = ""
def go():
- ctx.current.commit()
+ Session.commit()
self.assert_sql_count(testbase.db, go, 0)
def testmultitable(self):
u.user_name = 'multitester'
u.email = 'multi@test.org'
- ctx.current.commit()
+ Session.commit()
id = m.primary_key_from_instance(u)
- ctx.current.close()
+ Session.close()
- u = ctx.current.get(User, id)
+ u = Session.get(User, id)
assert u.user_name == 'multitester'
usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall()
u.email = 'lala@hey.com'
u.user_name = 'imnew'
- ctx.current.commit()
+ Session.commit()
usertable = users.select(users.c.user_id.in_(u.foo_id)).execute().fetchall()
self.assertEqual(usertable[0].values(), [u.foo_id, 'imnew'])
addresstable = addresses.select(addresses.c.address_id.in_(u.address_id)).execute().fetchall()
self.assertEqual(addresstable[0].values(), [u.address_id, u.foo_id, 'lala@hey.com'])
- ctx.current.close()
- u = ctx.current.get(User, id)
+ Session.close()
+ u = Session.get(User, id)
assert u.user_name == 'imnew'
def testhistoryget(self):
u = User()
u.addresses.append(Address())
u.addresses.append(Address())
- ctx.current.commit()
- ctx.current.close()
- u = ctx.current.query(User).get(u.user_id)
- ctx.current.delete(u)
- ctx.current.commit()
+ Session.commit()
+ Session.close()
+ u = Session.query(User).get(u.user_id)
+ Session.delete(u)
+ Session.commit()
assert users.count().scalar() == 0
assert addresses.count().scalar() == 0
u1.username = 'user1'
u2 = User()
u2.username = 'user2'
- ctx.current.commit()
+ Session.commit()
clear_mappers()
u2 = User()
u2.username = 'user2'
try:
- ctx.current.commit()
+ Session.commit()
assert False
except AssertionError:
assert True
a.user.user_name = elem['user_name']
objects.append(a)
- ctx.current.commit()
+ Session.commit()
objects[2].email_address = 'imnew@foo.bar'
objects[3].user = User()
objects[3].user.user_name = 'imnewlyadded'
- self.assert_sql(testbase.db, lambda: ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda: Session.commit(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'imnewlyadded'}
u1.user_name='user1'
a1.user = u1
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ u1 = Session.query(User).get(u1.user_id)
assert a1.user is u1
a1.user = None
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ u1 = Session.query(User).get(u1.user_id)
assert a1.user is None
def testmanytoone_2(self):
u1.user_name='user1'
a1.user = u1
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- a2 = ctx.current.query(Address).get(a2.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ a2 = Session.query(Address).get(a2.address_id)
+ u1 = Session.query(User).get(u1.user_id)
assert a1.user is u1
a1.user = None
a2.user = u1
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- a2 = ctx.current.query(Address).get(a2.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ a2 = Session.query(Address).get(a2.address_id)
+ u1 = Session.query(User).get(u1.user_id)
assert a1.user is None
assert a2.user is u1
u2.user_name='user2'
a1.user = u1
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
- u2 = ctx.current.query(User).get(u2.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ u1 = Session.query(User).get(u1.user_id)
+ u2 = Session.query(User).get(u2.user_id)
assert a1.user is u1
a1.user = u2
- ctx.current.commit()
- ctx.current.close()
- a1 = ctx.current.query(Address).get(a1.address_id)
- u1 = ctx.current.query(User).get(u1.user_id)
- u2 = ctx.current.query(User).get(u2.user_id)
+ Session.commit()
+ Session.close()
+ a1 = Session.query(Address).get(a1.address_id)
+ u1 = Session.query(User).get(u1.user_id)
+ u2 = Session.query(User).get(u2.user_id)
assert a1.user is u2
class ManyToManyTest(UnitOfWorkTest):
item.item_name = elem['item_name']
item.keywords = []
if len(elem['keywords'][1]):
- klist = ctx.current.query(keywordmapper).select(keywords.c.name.in_(*[e['name'] for e in elem['keywords'][1]]))
+ klist = Session.query(keywordmapper).select(keywords.c.name.in_(*[e['name'] for e in elem['keywords'][1]]))
else:
klist = []
khash = {}
k.name = kname
item.keywords.append(k)
- ctx.current.commit()
+ Session.commit()
- l = ctx.current.query(m).select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name])
+ l = Session.query(m).select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name])
self.assert_result(l, *data)
objects[4].item_name = 'item4updated'
k = Keyword()
k.name = 'yellow'
objects[5].keywords.append(k)
- self.assert_sql(testbase.db, lambda:ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda:Session.commit(), [
{
"UPDATE items SET item_name=:item_name WHERE items.item_id = :items_item_id":
{'item_name': 'item4updated', 'items_item_id': objects[4].item_id}
objects[2].keywords.append(k)
dkid = objects[5].keywords[1].keyword_id
del objects[5].keywords[1]
- self.assert_sql(testbase.db, lambda:ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda:Session.commit(), [
(
"DELETE FROM itemkeywords WHERE itemkeywords.item_id = :item_id AND itemkeywords.keyword_id = :keyword_id",
[{'item_id': objects[5].item_id, 'keyword_id': dkid}]
)
])
- ctx.current.delete(objects[3])
- ctx.current.commit()
+ Session.delete(objects[3])
+ Session.commit()
def testmanytomanyremove(self):
"""tests that setting a list-based attribute to '[]' properly affects the history and allows
k2 = Keyword()
i.keywords.append(k1)
i.keywords.append(k2)
- ctx.current.commit()
+ Session.commit()
assert itemkeywords.count().scalar() == 2
i.keywords = []
- ctx.current.commit()
+ Session.commit()
assert itemkeywords.count().scalar() == 0
def testscalar(self):
))
i = Item()
- ctx.current.commit()
- ctx.current.delete(i)
- ctx.current.commit()
+ Session.commit()
+ Session.delete(i)
+ Session.commit()
item.keywords.append(k1)
item.keywords.append(k2)
item.keywords.append(k3)
- ctx.current.commit()
+ Session.commit()
item.keywords = []
item.keywords.append(k1)
item.keywords.append(k2)
- ctx.current.commit()
+ Session.commit()
- ctx.current.close()
- item = ctx.current.query(Item).get(item.item_id)
+ Session.close()
+ item = Session.query(Item).get(item.item_id)
print [k1, k2]
print item.keywords
assert item.keywords == [k1, k2]
ik.keyword = k
item.keywords.append(ik)
- ctx.current.commit()
- ctx.current.close()
+ Session.commit()
+ Session.close()
l = Query(m).select(items.c.item_name.in_(*[e['item_name'] for e in data[1:]]), order_by=[items.c.item_name])
self.assert_result(l, *data)
k = KeywordUser()
k.user_name = 'keyworduser'
k.keyword_name = 'a keyword'
- ctx.current.commit()
+ Session.commit()
id = (k.user_id, k.keyword_id)
- ctx.current.close()
- k = ctx.current.query(KeywordUser).get(id)
+ Session.close()
+ k = Session.query(KeywordUser).get(id)
assert k.user_name == 'keyworduser'
assert k.keyword_name == 'a keyword'
class SaveTest2(UnitOfWorkTest):
def setUp(self):
- ctx.current.close()
+ Session.close()
clear_mappers()
global meta, users, addresses
meta = MetaData(testbase.db)
a.user = User()
a.user.user_name = elem['user_name']
objects.append(a)
- self.assert_sql(testbase.db, lambda: ctx.current.commit(), [
+ self.assert_sql(testbase.db, lambda: Session.commit(), [
(
"INSERT INTO users (user_name) VALUES (:user_name)",
{'user_name': 'thesub'}
k2 = Keyword()
i.keywords.append(k1)
i.keywords.append(k2)
- ctx.current.commit()
+ Session.commit()
assert t2.count().scalar() == 2
i.keywords = []
print i.keywords
- ctx.current.commit()
+ Session.commit()
assert t2.count().scalar() == 0