]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- rework the previous "order by" system in terms of the new one,
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Sep 2014 20:31:11 +0000 (16:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 8 Sep 2014 20:31:11 +0000 (16:31 -0400)
unify everything.
- create a new layer of separation between the "from order bys" and "column order bys",
so that an OVER doesn't ORDER BY a label in the same columns clause
- identify another issue with polymorphic for ref #3148, match on label
keys rather than the objects

lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/util.py
test/orm/test_query.py
test/sql/test_compiler.py

index 72dd11eaf5077e98d2914549cb168aec40c7fa1d..5149fa4feebcdba26ea2433a3e3528da3823530f 100644 (file)
@@ -503,7 +503,35 @@ class SQLCompiler(Compiled):
     def visit_grouping(self, grouping, asfrom=False, **kwargs):
         return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
 
-    def visit_label_reference(self, element, **kwargs):
+    def visit_label_reference(
+            self, element, within_columns_clause=False, **kwargs):
+        if self.stack and self.dialect.supports_simple_order_by_label:
+            selectable = self.stack[-1]['selectable']
+
+            with_cols, only_froms = selectable._label_resolve_dict
+            if within_columns_clause:
+                resolve_dict = only_froms
+            else:
+                resolve_dict = with_cols
+
+            # this can be None in the case that a _label_reference()
+            # were subject to a replacement operation, in which case
+            # the replacement of the Label element may have changed
+            # to something else like a ColumnClause expression.
+            order_by_elem = element.element._order_by_label_element
+
+            if order_by_elem is not None and order_by_elem.name in \
+                    resolve_dict:
+
+                kwargs['render_label_as_label'] = \
+                    element.element._order_by_label_element
+
+        return self.process(
+            element.element, within_columns_clause=within_columns_clause,
+            **kwargs)
+
+    def visit_textual_label_reference(
+            self, element, within_columns_clause=False, **kwargs):
         if not self.stack:
             # compiling the element outside of the context of a SELECT
             return self.process(
@@ -511,19 +539,25 @@ class SQLCompiler(Compiled):
             )
 
         selectable = self.stack[-1]['selectable']
+        with_cols, only_froms = selectable._label_resolve_dict
+
         try:
-            col = selectable._label_resolve_dict[element.text]
+            if within_columns_clause:
+                col = only_froms[element.element]
+            else:
+                col = with_cols[element.element]
         except KeyError:
             # treat it like text()
             util.warn_limited(
                 "Can't resolve label reference %r; converting to text()",
-                util.ellipses_string(element.text))
+                util.ellipses_string(element.element))
             return self.process(
                 element._text_clause
             )
         else:
             kwargs['render_label_as_label'] = col
-            return self.process(col, **kwargs)
+            return self.process(
+                col, within_columns_clause=within_columns_clause, **kwargs)
 
     def visit_label(self, label,
                     add_to_result_map=None,
@@ -678,11 +712,7 @@ class SQLCompiler(Compiled):
         else:
             return "0"
 
-    def visit_clauselist(self, clauselist, order_by_select=None, **kw):
-        if order_by_select is not None:
-            return self._order_by_clauselist(
-                clauselist, order_by_select, **kw)
-
+    def visit_clauselist(self, clauselist, **kw):
         sep = clauselist.operator
         if sep is None:
             sep = " "
@@ -695,26 +725,6 @@ class SQLCompiler(Compiled):
                 for c in clauselist.clauses)
             if s)
 
-    def _order_by_clauselist(self, clauselist, order_by_select, **kw):
-        # look through raw columns collection for labels.
-        # note that its OK we aren't expanding tables and other selectables
-        # here; we can only add a label in the ORDER BY for an individual
-        # label expression in the columns clause.
-
-        raw_col = set(order_by_select._label_resolve_dict.keys())
-
-        return ", ".join(
-            s for s in
-            (
-                c._compiler_dispatch(
-                    self,
-                    render_label_as_label=c._order_by_label_element if
-                    c._order_by_label_element is not None and
-                    c._order_by_label_element._label in raw_col
-                    else None,
-                    **kw)
-                for c in clauselist.clauses)
-            if s)
 
     def visit_case(self, clause, **kwargs):
         x = "CASE "
@@ -1590,13 +1600,7 @@ class SQLCompiler(Compiled):
                 text += " \nHAVING " + t
 
         if select._order_by_clause.clauses:
-            if self.dialect.supports_simple_order_by_label:
-                order_by_select = select
-            else:
-                order_by_select = None
-
-            text += self.order_by_clause(
-                select, order_by_select=order_by_select, **kwargs)
+            text += self.order_by_clause(select, **kwargs)
 
         if (select._limit_clause is not None or
                 select._offset_clause is not None):
index cf8de936df39ab731d9aa78f9fd42bc676cf2534..8ec0aa7002d74eaf2c21b6ffe9c1fda463c0a90f 100644 (file)
@@ -2356,14 +2356,39 @@ class Extract(ColumnElement):
 
 
 class _label_reference(ColumnElement):
+    """Wrap a column expression as it appears in a 'reference' context.
+
+    This expression is any that inclues an _order_by_label_element,
+    which is a Label, or a DESC / ASC construct wrapping a Label.
+
+    The production of _label_reference() should occur when an expression
+    is added to this context; this includes the ORDER BY or GROUP BY of a
+    SELECT statement, as well as a few other places, such as the ORDER BY
+    within an OVER clause.
+
+    """
     __visit_name__ = 'label_reference'
 
-    def __init__(self, text):
-        self.text = self.key = text
+    def __init__(self, element):
+        self.element = element
+
+    def _copy_internals(self, clone=_clone, **kw):
+        self.element = clone(self.element, **kw)
+
+    @property
+    def _from_objects(self):
+        return ()
+
+
+class _textual_label_reference(ColumnElement):
+    __visit_name__ = 'textual_label_reference'
+
+    def __init__(self, element):
+        self.element = element
 
     @util.memoized_property
     def _text_clause(self):
-        return TextClause._create_text(self.text)
+        return TextClause._create_text(self.element)
 
 
 class UnaryExpression(ColumnElement):
@@ -3556,6 +3581,13 @@ def _clause_element_as_expr(element):
 
 def _literal_as_label_reference(element):
     if isinstance(element, util.string_types):
+        return _textual_label_reference(element)
+
+    elif hasattr(element, '__clause_element__'):
+        element = element.__clause_element__()
+
+    if isinstance(element, ColumnElement) and \
+            element._order_by_label_element is not None:
         return _label_reference(element)
     else:
         return _literal_as_text(element)
index 57b16f45f36c5da5b17d92e783a81ec09cccae0a..0f29263502747b583c21ea9de82a8c8262131b4e 100644 (file)
@@ -1885,9 +1885,10 @@ class CompoundSelect(GenerativeSelect):
 
     @property
     def _label_resolve_dict(self):
-        return dict(
+        d = dict(
             (c.key, c) for c in self.c
         )
+        return d, d
 
     @classmethod
     def _create_union(cls, *selects, **kwargs):
@@ -2499,15 +2500,16 @@ class Select(HasPrefixes, GenerativeSelect):
 
     @_memoized_property
     def _label_resolve_dict(self):
-        d = dict(
+        with_cols = dict(
             (c._resolve_label or c._label or c.key, c)
             for c in _select_iterables(self._raw_columns)
             if c._allow_label_resolve)
-        d.update(
+        only_froms = dict(
             (c.key, c) for c in
             _select_iterables(self.froms) if c._allow_label_resolve)
+        with_cols.update(only_froms)
 
-        return d
+        return with_cols, only_froms
 
     def is_derived_from(self, fromclause):
         if self in fromclause._cloned_set:
index d6f3b591594ab5248bfa59f4ec78d6fc457b1d6e..fbbe15da3fd2e33b774270360b3038f9f8ca5e39 100644 (file)
@@ -16,7 +16,7 @@ from itertools import chain
 from collections import deque
 
 from .elements import BindParameter, ColumnClause, ColumnElement, \
-    Null, UnaryExpression, literal_column, Label
+    Null, UnaryExpression, literal_column, Label, _label_reference
 from .selectable import ScalarSelect, Join, FromClause, FromGrouping
 from .schema import Column
 
@@ -161,6 +161,8 @@ def unwrap_order_by(clause):
                 not isinstance(t, UnaryExpression) or
                 not operators.is_ordering_modifier(t.modifier)
         ):
+            if isinstance(t, _label_reference):
+                t = t.element
             cols.add(t)
         else:
             for c in t.get_children():
index 3f6813138247f413a8ce82949219c09ca5d71dca..c9f0a5db0febc8c9bf71016d5b9d372307053082 100644 (file)
@@ -1236,7 +1236,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
     __dialect__ = 'default'
     run_setup_mappers = 'each'
 
-    def _fixture(self, label=True):
+    def _fixture(self, label=True, polymorphic=False):
         User, Address = self.classes("User", "Address")
         users, addresses = self.tables("users", "addresses")
         stmt = select([func.max(addresses.c.email_address)]).\
@@ -1247,7 +1247,7 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
 
         mapper(User, users, properties={
             "ead": column_property(stmt)
-        })
+        }, with_polymorphic="*" if polymorphic else None)
         mapper(Address, addresses)
 
     def test_order_by_column_prop_string(self):
@@ -1355,6 +1355,22 @@ class ColumnPropertyTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             "users AS users_1 ORDER BY email_ad, anon_1"
         )
 
+    def test_order_by_column_labeled_prop_attr_aliased_four(self):
+        User = self.classes.User
+        self._fixture(label=True, polymorphic=True)
+
+        ua = aliased(User)
+        s = Session()
+        q = s.query(ua, User.id).order_by(ua.ead)
+        self.assert_compile(
+            q,
+            "SELECT (SELECT max(addresses.email_address) AS max_1 FROM "
+            "addresses WHERE addresses.user_id = users_1.id) AS anon_1, "
+            "users_1.id AS users_1_id, users_1.name AS users_1_name, "
+            "users.id AS users_id FROM users AS users_1, users ORDER BY anon_1"
+        )
+
+
     def test_order_by_column_unlabeled_prop_attr_aliased_one(self):
         User = self.classes.User
         self._fixture(label=False)
index 4f8ced72c00c2b0c73274e1c5814d0864d262f94..d47b58f1f05c6441d6cb70be4c1e99a0a97ff194 100644 (file)
@@ -2169,6 +2169,27 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT x + foo() OVER () AS anon_1"
         )
 
+        # test a reference to a label that in the referecned selectable;
+        # this resolves
+        expr = (table1.c.myid + 5).label('sum')
+        stmt = select([expr]).alias()
+        self.assert_compile(
+            select([stmt.c.sum, func.row_number().over(order_by=stmt.c.sum)]),
+            "SELECT anon_1.sum, row_number() OVER (ORDER BY anon_1.sum) "
+            "AS anon_2 FROM (SELECT mytable.myid + :myid_1 AS sum "
+            "FROM mytable) AS anon_1"
+        )
+
+        # test a reference to a label that's at the same level as the OVER
+        # in the columns clause; doesn't resolve
+        expr = (table1.c.myid + 5).label('sum')
+        self.assert_compile(
+            select([expr, func.row_number().over(order_by=expr)]),
+            "SELECT mytable.myid + :myid_1 AS sum, "
+            "row_number() OVER "
+            "(ORDER BY mytable.myid + :myid_1) AS anon_1 FROM mytable"
+        )
+
     def test_date_between(self):
         import datetime
         table = Table('dt', metadata,