]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- added Query.with_session() method, switches
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Apr 2011 22:38:01 +0000 (18:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Apr 2011 22:38:01 +0000 (18:38 -0400)
Query to use a different session.

- horizontal shard query should use execution
options per connection as per [ticket:2131]

CHANGES
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/query.py
test/lib/requires.py
test/orm/test_query.py

diff --git a/CHANGES b/CHANGES
index 87c20d7a05e6f2c5db301f23c5d05f958a986e56..abdb9299e9c67f473a0661493f047818788612d2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -21,6 +21,9 @@ CHANGES
     a deprecation warning in 0.6.8.  
     [ticket:2144]
 
+  - added Query.with_session() method, switches 
+    Query to use a different session.
+
 - sql
   - Some improvements to error handling inside
     of the execute procedure to ensure auto-close
index dfd471c78cb4a14203ca1b7efaf5e7274fc359c0..6aafb227455d2e93188e6c1179c8e80b860e56cf 100644 (file)
@@ -40,20 +40,21 @@ class ShardedQuery(Query):
         return q
 
     def _execute_and_instances(self, context):
-        if self._shard_id is not None:
-            context.attributes['shard_id'] = self._shard_id
-            result = self.session.connection(
+        def iter_for_shard(shard_id):
+            context.attributes['shard_id'] = shard_id
+            result = self._connection_from_session(
                             mapper=self._mapper_zero(),
-                            shard_id=self._shard_id).execute(context.statement, self._params)
+                            shard_id=shard_id).execute(
+                                                context.statement, 
+                                                self._params)
             return self.instances(result, context)
+
+        if self._shard_id is not None:
+            return iter_for_shard(self._shard_id)
         else:
             partial = []
             for shard_id in self.query_chooser(self):
-                context.attributes['shard_id'] = shard_id
-                result = self.session.connection(
-                            mapper=self._mapper_zero(),
-                            shard_id=shard_id).execute(context.statement, self._params)
-                partial = partial + list(self.instances(result, context))
+                partial.extend(iter_for_shard(shard_id))
 
             # if some kind of in memory 'sorting' 
             # were done, this is where it would happen
index ef42e0d3ab401fc5b85cdb1d0aece72f3894bdba..75fd5870e06c4a3479f2bb249d8d0a13b53859d1 100644 (file)
@@ -773,6 +773,14 @@ class Query(object):
         m = _MapperEntity(self, entity)
         self._setup_aliasizers([m])
 
+    @_generative()
+    def with_session(self, session):
+        """Return a :class:`Query` that will use the given :class:`.Session`.
+
+        """
+
+        self.session = session
+
     def from_self(self, *entities):
         """return a Query that selects from this Query's 
         SELECT statement.
@@ -1766,13 +1774,18 @@ class Query(object):
             self.session._autoflush()
         return self._execute_and_instances(context)
 
-    def _execute_and_instances(self, querycontext):
+    def _connection_from_session(self, **kw):
         conn = self.session.connection(
+                        **kw)
+        if self._execution_options:
+            conn = conn.execution_options(**self._execution_options)
+        return conn
+
+    def _execute_and_instances(self, querycontext):
+        conn = self._connection_from_session(
                         mapper = self._mapper_zero_or_none(),
                         clause = querycontext.statement,
                         close_with_result=True)
-        if self._execution_options:
-            conn = conn.execution_options(**self._execution_options)
 
         result = conn.execute(querycontext.statement, self._params)
         return self.instances(result, querycontext)
index 1be308fe7e89285406270f2b014807651ac11552..5f3eb1c9aca37be4e527c6a2f445ce9c3abd78e3 100644 (file)
@@ -287,12 +287,17 @@ def cextensions(fn):
     )
 
 def dbapi_lastrowid(fn):
-    return _chain_decorators_on(
-        fn,
-        fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
-                                   'sqlite+pysqlite', 'mysql+pymysql'),
-        fails_if(lambda: util.pypy),
-    )
+    if util.pypy:
+        return _chain_decorators_on(
+            fn,
+            fails_if(lambda:True)
+        )
+    else:
+        return _chain_decorators_on(
+            fn,
+            fails_on_everything_except('mysql+mysqldb', 'mysql+oursql',
+                                       'sqlite+pysqlite', 'mysql+pymysql'),
+        )
 
 def sane_multi_rowcount(fn):
     return _chain_decorators_on(
index a87e1398a59626512f404dd79c4f96624a2be93d..3a60e878d3bc0776203672242f91d7838c54389c 100644 (file)
@@ -66,6 +66,19 @@ class QueryTest(_fixtures.FixtureTest):
 
         configure_mappers()
 
+class MiscTest(QueryTest):
+    run_create_tables = None
+    run_inserts = None
+
+    def test_with_session(self):
+        User = self.classes.User
+        s1 = Session()
+        s2 = Session()
+        q1 = s1.query(User)
+        q2 = q1.with_session(s2)
+        assert q2.session is s2
+        assert q1.session is s1
+
 class RowTupleTest(QueryTest):
     run_setup_mappers = None