From: Mike Bayer Date: Mon, 16 Jul 2007 21:01:23 +0000 (+0000) Subject: fix to SessionTransaction so it holds onto a Connection properly X-Git-Tag: rel_0_4_6~103 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=56fe7f8cb79cd5b90c43153e8636138d8e1d8e10;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix to SessionTransaction so it holds onto a Connection properly --- diff --git a/CHANGES b/CHANGES index 83df609f10..d777d45dfc 100644 --- a/CHANGES +++ b/CHANGES @@ -71,14 +71,25 @@ - added undefer_group() MapperOption, sets a set of "deferred" columns joined by a "group" to load as "undeferred". - + + - session enhancements/fixes: + - session can be bound to Connections + + - rewrite of the "deterministic alias name" logic to be part of the + SQL layer, produces much simpler alias and label names more in the + style of Hibernate + - sql - - added context manager (with statement) support for transactions - - added support for two phase commit, works with mysql and postgres so far. - - added a subtransaction implementation that uses savepoints. - - added support for savepoints. - - DynamicMetaData has been renamed to ThreadLocalMetaData - - BoundMetaData has been removed- regular MetaData is equivalent + - transactions: + - added context manager (with statement) support for transactions + - added support for two phase commit, works with mysql and postgres so far. + - added a subtransaction implementation that uses savepoints. + - added support for savepoints. + - MetaData: + - DynamicMetaData has been renamed to ThreadLocalMetaData + - BoundMetaData has been removed- regular MetaData is equivalent + - "anonymous" alias and label names are now generated at SQL compilation + time in a completely deterministic fashion...no more random hex IDs - significant architectural overhaul to SQL elements (ClauseElement). all elements share a common "mutability" framework which allows a consistent approach to in-place modifications of elements as well as diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a4e2287ded..d482286045 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -4,7 +4,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import util, exceptions, sql +from sqlalchemy import util, exceptions, sql, engine from sqlalchemy.orm import unitofwork, query from sqlalchemy.orm.mapper import object_mapper as _object_mapper from sqlalchemy.orm.mapper import class_mapper as _class_mapper @@ -30,8 +30,6 @@ class SessionTransaction(object): def connection(self, mapper_or_class, entity_name=None): if isinstance(mapper_or_class, type): mapper_or_class = _class_mapper(mapper_or_class, entity_name=entity_name) - if self.parent is not None: - return self.parent.connection(mapper_or_class) engine = self.session.get_bind(mapper_or_class) return self.get_or_add(engine) @@ -39,28 +37,36 @@ class SessionTransaction(object): return SessionTransaction(self.session, self) def add(self, bind): + if self.parent is not None: + return self.parent.add(bind) + if self.connections.has_key(bind.engine): raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") return self.get_or_add(bind) def get_or_add(self, bind): - # we reference the 'engine' attribute on the given object, which in the case of - # Connection, ProxyEngine, Engine, whatever, should return the original - # "Engine" object that is handling the connection. - if self.connections.has_key(bind.engine): - return self.connections[bind.engine][0] - e = bind.engine - c = bind.contextual_connect() - if not self.connections.has_key(e): - self.connections[e] = (c, c.begin(), c is not bind) - return self.connections[e][0] + if self.parent is not None: + return self.parent.get_or_add(bind) + + if self.connections.has_key(bind): + return self.connections[bind][0] + + if not isinstance(bind, engine.Connection): + e = bind + c = bind.contextual_connect() + else: + e = bind.engine + c = bind + + self.connections[bind] = self.connections[e] = (c, c.begin(), c is not bind) + return self.connections[bind][0] def commit(self): if self.parent is not None: return if self.autoflush: self.session.flush() - for t in self.connections.values(): + for t in util.Set(self.connections.values()): t[1].commit() self.close() diff --git a/test/orm/session.py b/test/orm/session.py index e67037b0af..9fa7e7f821 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -84,6 +84,21 @@ class SessionTest(AssertMixin): transaction.rollback() assert len(sess.query(User).select()) == 0 + def test_bound_connection(self): + class User(object):pass + mapper(User, users) + c = testbase.db.connect() + sess = create_session(bind=c) + transaction = sess.create_transaction() + trans2 = sess.create_transaction() + u = User() + sess.save(u) + sess.flush() + assert transaction.get_or_add(testbase.db) is trans2.get_or_add(testbase.db) #is transaction.get_or_add(c) is trans2.get_or_add(c) is c + trans2.commit() + transaction.rollback() + assert len(sess.query(User).select()) == 0 + def test_close_two(self): c = testbase.db.connect() try: