]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the change for #918 was of course not nearly that simple.
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Mar 2015 03:51:12 +0000 (22:51 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 8 Mar 2015 04:03:27 +0000 (23:03 -0500)
The "wrapping" employed by the mssql and oracle dialects using the
"iswrapper" argument was not being used intelligently by the compiler,
and the result map was being written incorrectly, using
*more* columns in the result map than were actually returned by
the statement, due to "row number" columns that are inside the
subquery.   The compiler now writes out result map on the
"top level" select in all cases
fully, and for the mssql/oracle wrapping case extracts out
the "proxied" columns in a second step, which only includes
those columns that are proxied outwards to the top level.

This change might have implications for 3rd party dialects that
might be imitating oracle's approach.   They can safely continue
to use the "iswrapper" kw which is now ignored, but they may
need to also add the _select_wraps argument as well.

lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
test/dialect/mssql/test_compiler.py
test/dialect/test_oracle.py

index 92d7e4ab3105fbd19290a78cd06d4838297b1278..a35ab80d320df76e0d9b5dc5b66d1963092d9260 100644 (file)
@@ -1031,6 +1031,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             _order_by_clauses = select._order_by_clause.clauses
             limit_clause = select._limit_clause
             offset_clause = select._offset_clause
+            kwargs['_select_wraps'] = select
             select = select._generate()
             select._mssql_visit = True
             select = select.column(
@@ -1048,7 +1049,7 @@ class MSSQLCompiler(compiler.SQLCompiler):
             else:
                 limitselect.append_whereclause(
                     mssql_rn <= (limit_clause))
-            return self.process(limitselect, iswrapper=True, **kwargs)
+            return self.process(limitselect, **kwargs)
         else:
             return compiler.SQLCompiler.visit_select(self, select, **kwargs)
 
index a5e071148a906174711eb5e479f0743dac4c6ef8..9ec84d268f7393f15e4dc70f932ac113a458292d 100644 (file)
@@ -665,8 +665,8 @@ class OracleCompiler(compiler.SQLCompiler):
         else:
             return sql.and_(*clauses)
 
-    def visit_outer_join_column(self, vc):
-        return self.process(vc.column) + "(+)"
+    def visit_outer_join_column(self, vc, **kw):
+        return self.process(vc.column, **kw) + "(+)"
 
     def visit_sequence(self, seq):
         return (self.dialect.identifier_preparer.format_sequence(seq) +
@@ -738,6 +738,7 @@ class OracleCompiler(compiler.SQLCompiler):
                 # limit=0
 
                 # TODO: use annotations instead of clone + attr set ?
+                kwargs['_select_wraps'] = select
                 select = select._generate()
                 select._oracle_visit = True
 
@@ -794,7 +795,6 @@ class OracleCompiler(compiler.SQLCompiler):
                     offsetselect._for_update_arg = select._for_update_arg
                     select = offsetselect
 
-        kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
         return compiler.SQLCompiler.visit_select(self, select, **kwargs)
 
     def limit_clause(self, select, **kw):
index 9d2bbfb157a860e8d47dad3d8bae590c6b374e16..62469d7201cc3694377d986d65ac4f9c98e41dd1 100644 (file)
@@ -663,7 +663,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
         self.cursor = self.create_cursor()
         return self
 
-    @property
+    @util.memoized_property
     def result_map(self):
         if self._result_columns:
             return self.compiled.result_map
index 61b6d22d0f7e5bb0ff8da021366e2dea168725f7..e37fa646c29168d5d14653ba6f251c5c61471339 100644 (file)
@@ -683,20 +683,18 @@ class SQLCompiler(Compiled):
                 self.post_process_text(textclause.text))
         )
 
-    def visit_text_as_from(self, taf, iswrapper=False,
-                           compound_index=0, force_result_map=False,
+    def visit_text_as_from(self, taf,
+                           compound_index=None, force_result_map=False,
                            asfrom=False,
                            parens=True, **kw):
 
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        populate_result_map = force_result_map or (
-            compound_index == 0 and (
-                toplevel or
-                entry['iswrapper']
-            )
-        )
+        populate_result_map = force_result_map or \
+            toplevel or \
+            (compound_index == 0 and entry.get(
+                'need_result_map_for_compound', False))
 
         if populate_result_map:
             self._ordered_columns = False
@@ -812,13 +810,16 @@ class SQLCompiler(Compiled):
                               parens=True, compound_index=0, **kwargs):
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
+        need_result_map = toplevel or \
+            (compound_index == 0
+                and entry.get('need_result_map_for_compound', False))
 
         self.stack.append(
             {
                 'correlate_froms': entry['correlate_froms'],
-                'iswrapper': toplevel,
                 'asfrom_froms': entry['asfrom_froms'],
-                'selectable': cs
+                'selectable': cs,
+                'need_result_map_for_compound': need_result_map
             })
 
         keyword = self.compound_keywords.get(cs.keyword)
@@ -840,8 +841,7 @@ class SQLCompiler(Compiled):
                  or cs._offset_clause is not None) and \
             self.limit_clause(cs, **kwargs) or ""
 
-        if self.ctes and \
-                compound_index == 0 and toplevel:
+        if self.ctes and toplevel:
             text = self._render_cte_clause() + text
 
         self.stack.pop(-1)
@@ -1460,7 +1460,6 @@ class SQLCompiler(Compiled):
         ]
 
     _default_stack_entry = util.immutabledict([
-        ('iswrapper', False),
         ('correlate_froms', frozenset()),
         ('asfrom_froms', frozenset())
     ])
@@ -1488,10 +1487,11 @@ class SQLCompiler(Compiled):
         return froms
 
     def visit_select(self, select, asfrom=False, parens=True,
-                     iswrapper=False, fromhints=None,
+                     fromhints=None,
                      compound_index=0,
                      force_result_map=False,
                      nested_join_translation=False,
+                     _select_wraps=None,
                      **kwargs):
 
         needs_nested_translation = \
@@ -1505,7 +1505,7 @@ class SQLCompiler(Compiled):
                 select)
             text = self.visit_select(
                 transformed_select, asfrom=asfrom, parens=parens,
-                iswrapper=iswrapper, fromhints=fromhints,
+                fromhints=fromhints,
                 compound_index=compound_index,
                 force_result_map=force_result_map,
                 nested_join_translation=True, **kwargs
@@ -1514,12 +1514,11 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
-        populate_result_map = force_result_map or (
-            compound_index == 0 and (
-                toplevel or
-                entry['iswrapper']
+        populate_result_map = force_result_map or \
+            toplevel or (
+                compound_index == 0 and entry.get(
+                    'need_result_map_for_compound', False)
             )
-        )
 
         if needs_nested_translation:
             if populate_result_map:
@@ -1527,7 +1526,7 @@ class SQLCompiler(Compiled):
                     select, transformed_select)
             return text
 
-        froms = self._setup_select_stack(select, entry, asfrom, iswrapper)
+        froms = self._setup_select_stack(select, entry, asfrom)
 
         column_clause_args = kwargs.copy()
         column_clause_args.update({
@@ -1553,16 +1552,34 @@ class SQLCompiler(Compiled):
         # the actual list of columns to print in the SELECT column list.
         inner_columns = [
             c for c in [
-                self._label_select_column(select,
-                                          column,
-                                          populate_result_map, asfrom,
-                                          column_clause_args,
-                                          name=name)
+                self._label_select_column(
+                    select,
+                    column,
+                    populate_result_map, asfrom,
+                    column_clause_args,
+                    name=name)
                 for name, column in select._columns_plus_names
             ]
             if c is not None
         ]
 
+        if populate_result_map and _select_wraps is not None:
+            # if this select is a compiler-generated wrapper,
+            # rewrite the targeted columns in the result map
+            wrapped_inner_columns = set(_select_wraps.inner_columns)
+            translate = dict(
+                (outer, inner.pop()) for outer, inner in [
+                    (
+                        outer,
+                        outer.proxy_set.intersection(wrapped_inner_columns))
+                    for outer in select.inner_columns
+                ] if inner
+            )
+            self._result_columns = [
+                (key, name, tuple(translate.get(o, o) for o in obj), type_)
+                for key, name, obj, type_ in self._result_columns
+            ]
+
         text = self._compose_select_body(
             text, select, inner_columns, froms, byfrom, kwargs)
 
@@ -1575,8 +1592,7 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        if self.ctes and \
-                compound_index == 0 and toplevel:
+        if self.ctes and toplevel:
             text = self._render_cte_clause() + text
 
         if select._suffixes:
@@ -1603,7 +1619,7 @@ class SQLCompiler(Compiled):
         hint_text = self.get_select_hint_text(byfrom)
         return hint_text, byfrom
 
-    def _setup_select_stack(self, select, entry, asfrom, iswrapper):
+    def _setup_select_stack(self, select, entry, asfrom):
         correlate_froms = entry['correlate_froms']
         asfrom_froms = entry['asfrom_froms']
 
@@ -1622,7 +1638,6 @@ class SQLCompiler(Compiled):
 
         new_entry = {
             'asfrom_froms': new_correlate_froms,
-            'iswrapper': iswrapper,
             'correlate_froms': all_correlate_froms,
             'selectable': select,
         }
@@ -1765,7 +1780,6 @@ class SQLCompiler(Compiled):
     def visit_insert(self, insert_stmt, **kw):
         self.stack.append(
             {'correlate_froms': set(),
-             "iswrapper": False,
              "asfrom_froms": set(),
              "selectable": insert_stmt})
 
@@ -1889,7 +1903,6 @@ class SQLCompiler(Compiled):
     def visit_update(self, update_stmt, **kw):
         self.stack.append(
             {'correlate_froms': set([update_stmt.table]),
-             "iswrapper": False,
              "asfrom_froms": set([update_stmt.table]),
              "selectable": update_stmt})
 
@@ -1975,7 +1988,6 @@ class SQLCompiler(Compiled):
 
     def visit_delete(self, delete_stmt, **kw):
         self.stack.append({'correlate_froms': set([delete_stmt.table]),
-                           "iswrapper": False,
                            "asfrom_froms": set([delete_stmt.table]),
                            "selectable": delete_stmt})
         self.isdelete = True
index 3de8ea5c99ccdb1c1325958fd0b926b4ef9f49a7..54a23ee6e07f8e4105aae5a0b0b44c0d0fc74dd6 100644 (file)
@@ -416,11 +416,14 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT TOP 0 t.x, t.y FROM t WHERE t.x = :x_1 ORDER BY t.y",
             checkparams={'x_1': 5}
         )
+        c = s.compile(dialect=mssql.MSDialect())
+        eq_(len(c._result_columns), 2)
+        assert t.c.x in set(c.result_map['x'][1])
 
     def test_offset_using_window(self):
         t = table('t', column('x', Integer), column('y', Integer))
 
-        s = select([t]).where(t.c.x==5).order_by(t.c.y).offset(20)
+        s = select([t]).where(t.c.x == 5).order_by(t.c.y).offset(20)
 
         # test that the select is not altered with subsequent compile
         # calls
@@ -434,6 +437,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                 checkparams={'param_1': 20, 'x_1': 5}
             )
 
+            c = s.compile(dialect=mssql.MSDialect())
+            eq_(len(c._result_columns), 2)
+            assert t.c.x in set(c.result_map['x'][1])
+
     def test_limit_offset_using_window(self):
         t = table('t', column('x', Integer), column('y', Integer))
 
@@ -449,6 +456,10 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1",
             checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5}
         )
+        c = s.compile(dialect=mssql.MSDialect())
+        eq_(len(c._result_columns), 2)
+        assert t.c.x in set(c.result_map['x'][1])
+        assert t.c.y in set(c.result_map['y'][1])
 
     def test_limit_offset_with_correlated_order_by(self):
         t1 = table('t1', column('x', Integer), column('y', Integer))
@@ -471,6 +482,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={'param_1': 20, 'param_2': 10, 'x_1': 5}
         )
 
+        c = s.compile(dialect=mssql.MSDialect())
+        eq_(len(c._result_columns), 2)
+        assert t1.c.x in set(c.result_map['x'][1])
+        assert t1.c.y in set(c.result_map['y'][1])
+
     def test_limit_zero_offset_using_window(self):
         t = table('t', column('x', Integer), column('y', Integer))
 
index 3c67f15903c1c94e9fa8495a7c0ba13dbfc8fb6a..58ea058c26b6478291bd7f32869406bab70dc91e 100644 (file)
@@ -240,9 +240,11 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             checkparams={'param_1': 10, 'param_2': 20})
 
         c = s.compile(dialect=oracle.OracleDialect())
+        eq_(len(c._result_columns), 2)
         assert t.c.col1 in set(c.result_map['col1'][1])
-        s = select([s.c.col1, s.c.col2])
-        self.assert_compile(s,
+
+        s2 = select([s.c.col1, s.c.col2])
+        self.assert_compile(s2,
                             'SELECT col1, col2 FROM (SELECT col1, col2 '
                             'FROM (SELECT col1, col2, ROWNUM AS ora_rn '
                             'FROM (SELECT sometable.col1 AS col1, '
@@ -251,13 +253,16 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             ':param_2)',
                             checkparams={'param_1': 10, 'param_2': 20})
 
-        self.assert_compile(s,
+        self.assert_compile(s2,
                             'SELECT col1, col2 FROM (SELECT col1, col2 '
                             'FROM (SELECT col1, col2, ROWNUM AS ora_rn '
                             'FROM (SELECT sometable.col1 AS col1, '
                             'sometable.col2 AS col2 FROM sometable) '
                             'WHERE ROWNUM <= :param_1 + :param_2) WHERE ora_rn > '
                             ':param_2)')
+        c = s2.compile(dialect=oracle.OracleDialect())
+        eq_(len(c._result_columns), 2)
+        assert s.c.col1 in set(c.result_map['col1'][1])
 
         s = select([t]).limit(10).offset(20).order_by(t.c.col2)
         self.assert_compile(s,
@@ -269,6 +274,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
                             ':param_1 + :param_2) WHERE ora_rn > :param_2',
                             checkparams={'param_1': 10, 'param_2': 20}
                             )
+        c = s.compile(dialect=oracle.OracleDialect())
+        eq_(len(c._result_columns), 2)
+        assert t.c.col1 in set(c.result_map['col1'][1])
 
         s = select([t], for_update=True).limit(10).order_by(t.c.col2)
         self.assert_compile(s,