]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rename _select_wraps
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Mar 2015 16:33:38 +0000 (12:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Mar 2015 16:33:38 +0000 (12:33 -0400)
- replace force_result_map with a mini-API for nested result sets, add
coverage

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/sql/compiler.py
test/sql/test_compiler.py

index a35ab80d320df76e0d9b5dc5b66d1963092d9260..3d531fe721f45a7788c738e1e2b9c4d0e6769486 100644 (file)
@@ -1031,7 +1031,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             _order_by_clauses = select._order_by_clause.clauses
             limit_clause = select._limit_clause
             offset_clause = select._offset_clause
-            kwargs['_select_wraps'] = select
+            kwargs['select_wraps_for'] = select
             select = select._generate()
             select._mssql_visit = True
             select = select.column(
index 9ec84d268f7393f15e4dc70f932ac113a458292d..33bb443eef3f8170136a7082e17243f6ba7eea6c 100644 (file)
@@ -737,8 +737,7 @@ class OracleCompiler(compiler.SQLCompiler):
                 # Outer select and "ROWNUM as ora_rn" can be dropped if
                 # limit=0
 
-                # TODO: use annotations instead of clone + attr set ?
-                kwargs['_select_wraps'] = select
+                kwargs['select_wraps_for'] = select
                 select = select._generate()
                 select._oracle_visit = True
 
index a3a247ac03a0b737e95268cfb14a2177e8b11fec..64fc365b96cc5881fafd2f4e7f53e9cc0c43ffd2 100644 (file)
@@ -23,6 +23,7 @@ To generate user-defined SQL strings, see
 
 """
 
+import contextlib
 import re
 from . import schema, sqltypes, operators, functions, visitors, \
     elements, selectable, crud
@@ -407,6 +408,26 @@ class SQLCompiler(Compiled):
         if self.positional:
             self.cte_positional = {}
 
+    @contextlib.contextmanager
+    def _nested_result(self):
+        """special API to support the use case of 'nested result sets'"""
+        result_columns, ordered_columns = (
+            self._result_columns, self._ordered_columns)
+        self._result_columns, self._ordered_columns = [], False
+
+        try:
+            if self.stack:
+                entry = self.stack[-1]
+                entry['need_result_map_for_nested'] = True
+            else:
+                entry = None
+            yield self._result_columns, self._ordered_columns
+        finally:
+            if entry:
+                entry.pop('need_result_map_for_nested')
+            self._result_columns, self._ordered_columns = (
+                result_columns, ordered_columns)
+
     def _apply_numbered_params(self):
         poscount = itertools.count(1)
         self.string = re.sub(
@@ -673,17 +694,18 @@ class SQLCompiler(Compiled):
         )
 
     def visit_text_as_from(self, taf,
-                           compound_index=None, force_result_map=False,
+                           compound_index=None,
                            asfrom=False,
                            parens=True, **kw):
 
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        populate_result_map = force_result_map or \
-            toplevel or \
-            (compound_index == 0 and entry.get(
-                'need_result_map_for_compound', False))
+        populate_result_map = toplevel or \
+            (
+                compound_index == 0 and entry.get(
+                    'need_result_map_for_compound', False)
+            ) or entry.get('need_result_map_for_nested', False)
 
         if populate_result_map:
             self._ordered_columns = False
@@ -1478,9 +1500,8 @@ class SQLCompiler(Compiled):
     def visit_select(self, select, asfrom=False, parens=True,
                      fromhints=None,
                      compound_index=0,
-                     force_result_map=False,
                      nested_join_translation=False,
-                     _select_wraps=None,
+                     select_wraps_for=None,
                      **kwargs):
 
         needs_nested_translation = \
@@ -1496,18 +1517,17 @@ class SQLCompiler(Compiled):
                 transformed_select, asfrom=asfrom, parens=parens,
                 fromhints=fromhints,
                 compound_index=compound_index,
-                force_result_map=force_result_map,
                 nested_join_translation=True, **kwargs
             )
 
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        populate_result_map = force_result_map or \
-            toplevel or (
+        populate_result_map = toplevel or \
+            (
                 compound_index == 0 and entry.get(
                     'need_result_map_for_compound', False)
-            )
+            ) or entry.get('need_result_map_for_nested', False)
 
         if needs_nested_translation:
             if populate_result_map:
@@ -1552,10 +1572,10 @@ class SQLCompiler(Compiled):
             if c is not None
         ]
 
-        if populate_result_map and _select_wraps is not None:
+        if populate_result_map and select_wraps_for is not None:
             # if this select is a compiler-generated wrapper,
             # rewrite the targeted columns in the result map
-            wrapped_inner_columns = set(_select_wraps.inner_columns)
+            wrapped_inner_columns = set(select_wraps_for.inner_columns)
             translate = dict(
                 (outer, inner.pop()) for outer, inner in [
                     (
index 73c1402f6fe3ca7a1e747b95e1198d02c894ddc0..4b143c1509b0b72bb3a8cc9e5c530031a79bfad3 100644 (file)
@@ -3466,3 +3466,64 @@ class ResultMapTest(fixtures.TestBase):
             comp._create_result_map(),
             {'a': ('a', (aint, 'a', 'a'), aint.type)}
         )
+
+    def test_nested_api(self):
+        from sqlalchemy.engine.result import ResultMetaData
+        stmt2 = select([table2])
+
+        stmt1 = select([table1]).select_from(stmt2)
+
+        contexts = {}
+
+        int_ = Integer()
+
+        class MyCompiler(compiler.SQLCompiler):
+            def visit_select(self, stmt, *arg, **kw):
+
+                if stmt is stmt2:
+                    with self._nested_result() as nested:
+                        contexts[stmt2] = nested
+                        text = super(MyCompiler, self).visit_select(stmt2)
+                        self._add_to_result_map("k1", "k1", (1, 2, 3), int_)
+                else:
+                    text = super(MyCompiler, self).visit_select(
+                        stmt, *arg, **kw)
+                    self._add_to_result_map("k2", "k2", (3, 4, 5), int_)
+                return text
+
+        comp = MyCompiler(default.DefaultDialect(), stmt1)
+
+        eq_(
+            ResultMetaData._create_result_map(contexts[stmt2][0]),
+            {
+                'otherid': (
+                    'otherid',
+                    (table2.c.otherid, 'otherid', 'otherid'),
+                    table2.c.otherid.type),
+                'othername': (
+                    'othername',
+                    (table2.c.othername, 'othername', 'othername'),
+                    table2.c.othername.type),
+                'k1': ('k1', (1, 2, 3), int_)
+            }
+        )
+        eq_(
+            comp._create_result_map(),
+            {
+                'myid': (
+                    'myid',
+                    (table1.c.myid, 'myid', 'myid'), table1.c.myid.type
+                ),
+                'k2': ('k2', (3, 4, 5), int_),
+                'name': (
+                    'name', (table1.c.name, 'name', 'name'),
+                    table1.c.name.type),
+                'description': (
+                    'description',
+                    (table1.c.description, 'description', 'description'),
+                    table1.c.description.type)}
+        )
+
+
+
+