From: Mike Bayer Date: Sat, 23 Apr 2011 22:38:01 +0000 (-0400) Subject: - added Query.with_session() method, switches X-Git-Tag: rel_0_7_0~35 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=02fa8bacaa69f1a4b246bed0f0b89998e33ae847;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - added Query.with_session() method, switches Query to use a different session. - horizontal shard query should use execution options per connection as per [ticket:2131] --- diff --git a/CHANGES b/CHANGES index 87c20d7a05..abdb9299e9 100644 --- a/CHANGES +++ b/CHANGES @@ -21,6 +21,9 @@ CHANGES a deprecation warning in 0.6.8. [ticket:2144] + - added Query.with_session() method, switches + Query to use a different session. + - sql - Some improvements to error handling inside of the execute procedure to ensure auto-close diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index dfd471c78c..6aafb22745 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -40,20 +40,21 @@ class ShardedQuery(Query): return q def _execute_and_instances(self, context): - if self._shard_id is not None: - context.attributes['shard_id'] = self._shard_id - result = self.session.connection( + def iter_for_shard(shard_id): + context.attributes['shard_id'] = shard_id + result = self._connection_from_session( mapper=self._mapper_zero(), - shard_id=self._shard_id).execute(context.statement, self._params) + shard_id=shard_id).execute( + context.statement, + self._params) return self.instances(result, context) + + if self._shard_id is not None: + return iter_for_shard(self._shard_id) else: partial = [] for shard_id in self.query_chooser(self): - context.attributes['shard_id'] = shard_id - result = self.session.connection( - mapper=self._mapper_zero(), - shard_id=shard_id).execute(context.statement, self._params) - partial = partial + list(self.instances(result, context)) + partial.extend(iter_for_shard(shard_id)) # if some kind of in memory 'sorting' # were done, this is where it would happen diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index ef42e0d3ab..75fd5870e0 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -773,6 +773,14 @@ class Query(object): m = _MapperEntity(self, entity) self._setup_aliasizers([m]) + @_generative() + def with_session(self, session): + """Return a :class:`Query` that will use the given :class:`.Session`. + + """ + + self.session = session + def from_self(self, *entities): """return a Query that selects from this Query's SELECT statement. @@ -1766,13 +1774,18 @@ class Query(object): self.session._autoflush() return self._execute_and_instances(context) - def _execute_and_instances(self, querycontext): + def _connection_from_session(self, **kw): conn = self.session.connection( + **kw) + if self._execution_options: + conn = conn.execution_options(**self._execution_options) + return conn + + def _execute_and_instances(self, querycontext): + conn = self._connection_from_session( mapper = self._mapper_zero_or_none(), clause = querycontext.statement, close_with_result=True) - if self._execution_options: - conn = conn.execution_options(**self._execution_options) result = conn.execute(querycontext.statement, self._params) return self.instances(result, querycontext) diff --git a/test/lib/requires.py b/test/lib/requires.py index 1be308fe7e..5f3eb1c9ac 100644 --- a/test/lib/requires.py +++ b/test/lib/requires.py @@ -287,12 +287,17 @@ def cextensions(fn): ) def dbapi_lastrowid(fn): - return _chain_decorators_on( - fn, - fails_on_everything_except('mysql+mysqldb', 'mysql+oursql', - 'sqlite+pysqlite', 'mysql+pymysql'), - fails_if(lambda: util.pypy), - ) + if util.pypy: + return _chain_decorators_on( + fn, + fails_if(lambda:True) + ) + else: + return _chain_decorators_on( + fn, + fails_on_everything_except('mysql+mysqldb', 'mysql+oursql', + 'sqlite+pysqlite', 'mysql+pymysql'), + ) def sane_multi_rowcount(fn): return _chain_decorators_on( diff --git a/test/orm/test_query.py b/test/orm/test_query.py index a87e1398a5..3a60e878d3 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -66,6 +66,19 @@ class QueryTest(_fixtures.FixtureTest): configure_mappers() +class MiscTest(QueryTest): + run_create_tables = None + run_inserts = None + + def test_with_session(self): + User = self.classes.User + s1 = Session() + s2 = Session() + q1 = s1.query(User) + q2 = q1.with_session(s2) + assert q2.session is s2 + assert q1.session is s1 + class RowTupleTest(QueryTest): run_setup_mappers = None