]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure we unwrap desc() /label() all the way w/ order by
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Jul 2020 15:26:39 +0000 (11:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 7 Jul 2020 16:39:13 +0000 (12:39 -0400)
The deprecated logic to move order_by expressions
up into the columns clause needed adjustment to accommodate
for a more deeply-wrapped structure when desc() + label()
 are combined in an order by column.  This structure
now comes from coercions in 1.4.  it's not clear to me
at the moment why it's different from 1.3 but
this shouldn't really matter.

Fixes: #5443
Change-Id: If909a86f715992318d7aa283603197f7711f1d3b

lib/sqlalchemy/sql/util.py
test/orm/test_deprecations.py
test/sql/test_utils.py

index e8726000b8caad0087a287ca021375c45f588f20..b803ef91276c418f0ab07c953e8a45a46d2c2d26 100644 (file)
@@ -27,6 +27,7 @@ from .elements import _textual_label_reference
 from .elements import BindParameter
 from .elements import ColumnClause
 from .elements import ColumnElement
+from .elements import Label
 from .elements import Null
 from .elements import UnaryExpression
 from .schema import Column
@@ -279,14 +280,31 @@ def unwrap_order_by(clause):
     cols = util.column_set()
     result = []
     stack = deque([clause])
+
+    # examples
+    # column -> ASC/DESC == column
+    # column -> ASC/DESC -> label == column
+    # column -> label -> ASC/DESC -> label == column
+    # scalar_select -> label -> ASC/DESC == scalar_select -> label
+
     while stack:
         t = stack.popleft()
         if isinstance(t, ColumnElement) and (
             not isinstance(t, UnaryExpression)
             or not operators.is_ordering_modifier(t.modifier)
         ):
-            if isinstance(t, _label_reference):
+            if isinstance(t, Label) and not isinstance(
+                t.element, ScalarSelect
+            ):
+                t = t.element
+
+                stack.append(t)
+                continue
+            elif isinstance(t, _label_reference):
                 t = t.element
+
+                stack.append(t)
+                continue
             if isinstance(t, (_textual_label_reference)):
                 continue
             if t not in cols:
index 4e9f50661ee17f2a57b158e685eee9ba25ffd4af..4cfded25c31f2d7aa03cebb444f36c50851c999e 100644 (file)
@@ -1597,6 +1597,21 @@ class DistinctOrderByImplicitTest(QueryTest, AssertsCompiledSQL):
         ):
             eq_([User(id=7), User(id=9), User(id=8)], q.all())
 
+    def test_columns_augmented_roundtrip_two(self):
+        User, Address = self.classes.User, self.classes.Address
+
+        sess = create_session()
+        q = (
+            sess.query(User)
+            .join("addresses")
+            .distinct()
+            .order_by(desc(Address.email_address).label("foo"))
+        )
+        with testing.expect_deprecated(
+            "ORDER BY columns added implicitly due to "
+        ):
+            eq_([User(id=7), User(id=9), User(id=8)], q.all())
+
     def test_columns_augmented_roundtrip_three(self):
         User, Address = self.classes.User, self.classes.Address
 
index d68a7447534fdc6d1859167629c9dddd34278802..676ad429820179267481ca5287bfd3ffa6d0fab2 100644 (file)
@@ -4,9 +4,14 @@ from sqlalchemy import MetaData
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
+from sqlalchemy import testing
+from sqlalchemy import util
 from sqlalchemy.sql import base as sql_base
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import column
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import roles
 from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -89,3 +94,28 @@ class MiscTest(fixtures.TestBase):
         eq_(o4.bat, "hi")
 
         assert_raises(TypeError, opt2.safe_merge, o4)
+
+    @testing.combinations(
+        (column("q"), [column("q")]),
+        (column("q").desc(), [column("q")]),
+        (column("q").desc().label(None), [column("q")]),
+        (column("q").label(None).desc(), [column("q")]),
+        (column("q").label(None).desc().label(None), [column("q")]),
+        ("foo", []),  # textual label reference
+        (
+            select([column("q")]).scalar_subquery().label(None),
+            [select([column("q")]).scalar_subquery().label(None)],
+        ),
+        (
+            select([column("q")]).scalar_subquery().label(None).desc(),
+            [select([column("q")]).scalar_subquery().label(None)],
+        ),
+    )
+    def test_unwrap_order_by(self, expr, expected):
+
+        expr = coercions.expect(roles.OrderByRole, expr)
+
+        unwrapped = sql_util.unwrap_order_by(expr)
+
+        for a, b in util.zip_longest(unwrapped, expected):
+            assert a is not None and a.compare(b)