return self.execute(object, *multiparams, **params).scalar()
def compiler(self, statement, parameters, **kwargs):
- return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+ return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
def execute(self, object, *multiparams, **params):
for c in type(object).__mro__:
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-from sqlalchemy import util, exceptions, sql
+from sqlalchemy import util, exceptions, sql, engine
from sqlalchemy.orm import unitofwork, query
from sqlalchemy.orm.mapper import object_mapper as _object_mapper
from sqlalchemy.orm.mapper import class_mapper as _class_mapper
def connection(self, mapper_or_class, entity_name=None):
if isinstance(mapper_or_class, type):
mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name)
- if self.parent is not None:
- return self.parent.connection(mapper_or_class)
engine = self.session.get_bind(mapper_or_class)
return self.get_or_add(engine)
return SessionTransaction(self.session, self)
def add(self, bind):
+ if self.parent is not None:
+ return self.parent.add(bind)
+
if self.connections.has_key(bind.engine):
raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
return self.get_or_add(bind)
def get_or_add(self, bind):
- # we reference the 'engine' attribute on the given object, which in the case of
- # Connection, ProxyEngine, Engine, whatever, should return the original
- # "Engine" object that is handling the connection.
- if self.connections.has_key(bind.engine):
- return self.connections[bind.engine][0]
- e = bind.engine
- c = bind.contextual_connect()
- if not self.connections.has_key(e):
- self.connections[e] = (c, c.begin(), c is not bind)
- return self.connections[e][0]
+ if self.parent is not None:
+ return self.parent.get_or_add(bind)
+
+ if self.connections.has_key(bind):
+ return self.connections[bind][0]
+
+ if not isinstance(bind, engine.Connection):
+ e = bind
+ c = bind.contextual_connect()
+ else:
+ e = bind.engine
+ c = bind
+
+ self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind)
+ return self.connections[bind][0]
def commit(self):
if self.parent is not None:
return
if self.autoflush:
self.session.flush()
- for t in self.connections.values():
+ for t in util.Set(self.connections.values()):
t[1].commit()
self.close()
table.drop()
assert not table.exists()
+ def test_implicit_execution(self):
+ metadata = MetaData()
+ table = Table('test_table', metadata,
+ Column('foo', Integer))
+ conn = testbase.db.connect()
+ metadata.create_all(bind=conn)
+ try:
+ trans = conn.begin()
+ metadata.bind = conn
+ t = table.insert()
+ assert t.bind is conn
+ table.insert().execute(foo=5)
+ table.insert().execute(foo=6)
+ table.insert().execute(foo=7)
+ trans.rollback()
+ metadata.bind = None
+ assert testbase.db.execute("select count(1) from test_table").scalar() == 0
+ finally:
+ metadata.drop_all(bind=conn)
+
def test_clauseelement(self):
metadata = MetaData()
transaction.rollback()
assert len(sess.query(User).select()) == 0
+ def test_bound_connection(self):
+ class User(object):pass
+ mapper(User, users)
+ c = testbase.db.connect()
+ sess = create_session(bind=c)
+ transaction = sess.create_transaction()
+ trans2 = sess.create_transaction()
+ u = User()
+ sess.save(u)
+ sess.flush()
+ assert transaction.get_or_add(testbase.db) is trans2.get_or_add(testbase.db) #is transaction.get_or_add(c) is trans2.get_or_add(c) is c
+ trans2.commit()
+ transaction.rollback()
+ assert len(sess.query(User).select()) == 0
+
def test_close_two(self):
c = testbase.db.connect()
try: