]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement join rewriting inside of visit_select(). Currently this is global or not...
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Jun 2013 22:05:47 +0000 (18:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 2 Jun 2013 22:05:47 +0000 (18:05 -0400)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py

index 41ef20a7a8ef2c48fb6bde5f72727aa4b649da2d..030d6dce93a651e0481584053dea4976f83cb4d0 100644 (file)
@@ -1077,23 +1077,64 @@ class SQLCompiler(engine.Compiled):
     def get_crud_hint_text(self, table, text):
         return None
 
+    def _transform_select_for_nested_joins(self, select):
+        adapters = []
+
+        traverse_options = {"cloned": {}}
+
+        def visit_join(elem):
+            if isinstance(elem.right, sql.FromGrouping):
+                selectable = sql.select([elem.right.element], use_labels=True)
+                selectable = selectable.alias()
+
+                while adapters:
+                    adapt = adapters.pop(-1)
+                    selectable = adapt.traverse(selectable)
+
+                for c in selectable.c:
+                    c._label = c._key_label = c.name
+
+                elem.right = selectable
+                adapters.append(
+                        sql_util.ClauseAdapter(selectable,
+                                        traverse_options=traverse_options)
+                )
+
+        select = visitors.cloned_traverse(select,
+                                    traverse_options, {"join": visit_join})
+
+        for adap in reversed(adapters):
+            select = adap.traverse(select)
+        return select
+
+    def _transform_result_map_for_nested_joins(self, select, transformed_select):
+        d = dict(zip(transformed_select.inner_columns, select.inner_columns))
+        for key, (name, objs, typ) in list(self.result_map.items()):
+            objs = tuple([d.get(col, col) for col in objs])
+            self.result_map[key] = (name, objs, typ)
+
     def visit_select(self, select, asfrom=False, parens=True,
                             iswrapper=False, fromhints=None,
                             compound_index=0,
                             force_result_map=False,
-                            positional_names=None, **kwargs):
-        entry = self.stack and self.stack[-1] or {}
+                            positional_names=None,
+                            nested_join_translation=False, **kwargs):
+
+        #nested_join_translation = True
+        if not nested_join_translation:
+            transformed_select = self._transform_select_for_nested_joins(select)
+            text = self.visit_select(
+                            transformed_select, asfrom=asfrom, parens=parens,
+                            iswrapper=iswrapper, fromhints=fromhints,
+                            compound_index=compound_index,
+                            force_result_map=force_result_map,
+                            positional_names=positional_names,
+                            nested_join_translation=True, **kwargs
+                        )
 
-        existingfroms = entry.get('from', None)
 
-        froms = select._get_display_froms(existingfroms, asfrom=asfrom)
 
-        correlate_froms = set(sql._from_objects(*froms))
-
-        # TODO: might want to propagate existing froms for
-        # select(select(select)) where innermost select should correlate
-        # to outermost if existingfroms: correlate_froms =
-        # correlate_froms.union(existingfroms)
+        entry = self.stack and self.stack[-1] or {}
 
         populate_result_map = force_result_map or (
                                         compound_index == 0 and (
@@ -1102,6 +1143,19 @@ class SQLCompiler(engine.Compiled):
                                         )
                                     )
 
+        if not nested_join_translation:
+            if populate_result_map:
+                self._transform_result_map_for_nested_joins(
+                                                select, transformed_select)
+            return text
+
+        existingfroms = entry.get('from', None)
+
+        froms = select._get_display_froms(existingfroms, asfrom=asfrom)
+
+        correlate_froms = set(sql._from_objects(*froms))
+
+
         self.stack.append({'from': correlate_froms,
                             'iswrapper': iswrapper})
 
index 91740dc16f024fd35351fd3d59c50c981d85c458..ffa07d3df5fcc875687156e521ed77fa8e72ec20 100644 (file)
@@ -797,8 +797,11 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
     def __init__(self, selectable, equivalents=None,
                         include=None, exclude=None,
                         include_fn=None, exclude_fn=None,
-                        adapt_on_names=False):
+                        adapt_on_names=False,
+                        traverse_options=None):
         self.__traverse_options__ = {'stop_on': [selectable]}
+        if traverse_options:
+            self.__traverse_options__.update(traverse_options)
         self.selectable = selectable
         if include:
             assert not include_fn
@@ -832,7 +835,7 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
     def replace(self, col):
         if isinstance(col, expression.FromClause) and \
             self.selectable.is_derived_from(col):
-                return self.selectable
+            return self.selectable
         elif not isinstance(col, expression.ColumnElement):
             return None
         elif self.include_fn and not self.include_fn(col):
index 62f46ab64cca5e22b58ad60848fbdd00be74f36c..31ac686e33d0c2925b133f465f9ed77ea0eeb274 100644 (file)
@@ -30,6 +30,7 @@ import operator
 __all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
     'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
     'iterate_depthfirst', 'traverse_using', 'traverse',
+    'traverse_depthfirst',
     'cloned_traverse', 'replacement_traverse']
 
 
@@ -255,7 +256,11 @@ def cloned_traverse(obj, opts, visitors):
     """clone the given expression structure, allowing
     modifications by visitors."""
 
-    cloned = util.column_dict()
+
+    if "cloned" in opts:
+        cloned = opts['cloned']
+    else:
+        cloned = util.column_dict()
     stop_on = util.column_set(opts.get('stop_on', []))
 
     def clone(elem):