]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed subtle bug that caused SQL to blow
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Jun 2011 23:25:35 +0000 (19:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Jun 2011 23:25:35 +0000 (19:25 -0400)
up if: column_property() against subquery +
joinedload + LIMIT + order by the column
property() occurred.  [ticket:2188].
Also in 0.6.9

CHANGES
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/util.py
test/orm/test_eager_relations.py
test/sql/test_selectable.py

diff --git a/CHANGES b/CHANGES
index 9aaccb06744f1f8ffac728d1c6478dba826de84d..74e1beba6372acb44f40df7fa802e37b5f111523 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -6,6 +6,12 @@ CHANGES
 0.7.2
 =====
 - orm
+  - Fixed subtle bug that caused SQL to blow
+    up if: column_property() against subquery +
+    joinedload + LIMIT + order by the column
+    property() occurred.  [ticket:2188].
+    Also in 0.6.9
+
   - Added the same "columns-only" check to 
     mapper.polymorphic_on as used in 
     relationship.order_by, foreign_keys,
@@ -30,9 +36,11 @@ CHANGES
     [ticket:2199].  Also in 0.6.9.
 
 - sql
-  - Fixed a subtle bug involving column 
-    correspondence in a selectable with the
-    same column repeated.   Affects [ticket:2188].
+  - Fixed two subtle bugs involving column 
+    correspondence in a selectable,
+    one with the same labeled subquery repeated, the other
+    when the label has been "grouped" and 
+    loses itself.  Affects [ticket:2188].
 
 - engine
   - Use urllib.parse_qsl() in Python 2.6 and above,
index 1674cd6c6cebb3ab7fab7a748ec0e18b4c9b13ad..2a13ce32138fe33443fa77af0dd4436787f19b1c 100644 (file)
@@ -2466,7 +2466,7 @@ class Query(object):
             if context.order_by:
                 order_by_col_expr = list(
                                         chain(*[
-                                            sql_util.find_columns(o) 
+                                            sql_util.unwrap_order_by(o)
                                             for o in context.order_by
                                         ])
                                     )
@@ -2480,6 +2480,9 @@ class Query(object):
                         from_obj=froms,
                         use_labels=labels,
                         correlate=False,
+                        # TODO: this order_by is only needed if 
+                        # LIMIT/OFFSET is present in self._select_args,
+                        # else the application on the outside is enough
                         order_by=context.order_by,
                         **self._select_args
                     )
@@ -2513,11 +2516,11 @@ class Query(object):
             statement.append_from(from_clause)
 
             if context.order_by:
-                    statement.append_order_by(
-                        *context.adapter.copy_and_process(
-                            context.order_by
-                        )
+                statement.append_order_by(
+                    *context.adapter.copy_and_process(
+                        context.order_by
                     )
+                )
 
             statement.append_order_by(*context.eager_order_by)
         else:
@@ -2527,7 +2530,7 @@ class Query(object):
             if self._distinct and context.order_by:
                 order_by_col_expr = list(
                                         chain(*[
-                                            sql_util.find_columns(o) 
+                                            sql_util.unwrap_order_by(o) 
                                             for o in context.order_by
                                         ])
                                     )
index 9dd7bd335956c622f502e9e0b282a7d5db85097e..66a87c26feef412fecdcaeeb01802df04d35bda6 100644 (file)
@@ -3710,7 +3710,7 @@ class _Label(ColumnElement):
                         sub_element, 
                         type_=self._type)
         else:
-            return self._element
+            return self
 
     @property
     def primary_key(self):
index dcea5a0f6dd4ade5719b819c8e4c4ab8fc2617d8..db6c40e9a58ec69e956f96c4795754454ad2c2f7 100644 (file)
@@ -524,6 +524,10 @@ _commutative = set([eq, ne, add, mul])
 def is_commutative(op):
     return op in _commutative
 
+def is_ordering_modifier(op):
+    return op in (asc_op, desc_op, 
+                    nullsfirst_op, nullslast_op)
+
 _associative = _commutative.union([concat_op, and_, or_])
 
 
index 1a3f7d2f8fdf299395daf5de798143893734973b..f003a969189aaaa06b295c7bf43f249fca6cb9c2 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import exc, schema, util, sql, types as sqltypes
 from sqlalchemy.util import topological
 from sqlalchemy.sql import expression, operators, visitors
 from itertools import chain
+from collections import deque
 
 """Utility functions that build upon SQL and Schema constructs."""
 
@@ -99,6 +100,25 @@ def find_columns(clause):
     visitors.traverse(clause, {}, {'column':cols.add})
     return cols
 
+def unwrap_order_by(clause):
+    """Break up an 'order by' expression into individual column-expressions,
+    without DESC/ASC/NULLS FIRST/NULLS LAST"""
+
+    cols = util.column_set()
+    stack = deque([clause])
+    while stack:
+        t = stack.popleft()
+        if isinstance(t, expression.ColumnElement) and \
+            (
+                not isinstance(t, expression._UnaryExpression) or \
+                not operators.is_ordering_modifier(t.modifier)
+            ): 
+            cols.add(t)
+        else:
+            for c in t.get_children():
+                stack.append(c)
+    return cols
+
 def clause_is_present(clause, search):
     """Given a target clause and a second to search within, return True
     if the target is plainly present in the search without any
@@ -624,11 +644,15 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
         self.equivalents = util.column_dict(equivalents or {})
 
     def _corresponding_column(self, col, require_embedded, _seen=util.EMPTY_SET):
-        newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
+        newcol = self.selectable.corresponding_column(
+                                    col, 
+                                    require_embedded=require_embedded)
 
         if newcol is None and col in self.equivalents and col not in _seen:
             for equiv in self.equivalents[col]:
-                newcol = self._corresponding_column(equiv, require_embedded=require_embedded, _seen=_seen.union([col]))
+                newcol = self._corresponding_column(equiv, 
+                                require_embedded=require_embedded, 
+                                _seen=_seen.union([col]))
                 if newcol is not None:
                     return newcol
         return newcol
index 543b6df5b1f263669c3ffee91224319437753f61..e3914e96c8af07b2965f8d9f9f8f955e2dcd9d77 100644 (file)
@@ -9,7 +9,7 @@ from sqlalchemy import Integer, String, Date, ForeignKey, and_, select, \
     func
 from test.lib.schema import Table, Column
 from sqlalchemy.orm import mapper, relationship, create_session, \
-    lazyload, aliased
+    lazyload, aliased, column_property
 from test.lib.testing import eq_, assert_raises, \
     assert_raises_message
 from test.lib.assertsql import CompiledSQL
@@ -18,6 +18,7 @@ from test.orm import _fixtures
 from sqlalchemy.util import OrderedDict as odict
 import datetime
 
+
 class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
     run_inserts = 'once'
     run_deletes = None
@@ -1329,6 +1330,182 @@ class EagerTest(_fixtures.FixtureTest, testing.AssertsCompiledSQL):
             use_default_dialect=True
         )
 
+class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+    """test #2188"""
+
+    __dialect__ = 'default'
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('a', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+
+        Table('b', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('a_id', Integer, ForeignKey('a.id')),
+            Column('value', Integer),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+
+        class A(cls.Comparable):
+            pass
+        class B(cls.Comparable):
+            pass
+
+    def _fixture(self, props):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        mapper(A,a_table, properties=props)
+        mapper(B,b_table,properties = {
+            'a':relationship(A, backref="bs")
+        })
+
+    def test_column_property(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id)
+
+        self._fixture({
+            'summation':column_property(cp)
+        })
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(A.summation).
+                            limit(50),
+            "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id "
+            "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
+            "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
+            "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
+        )
+
+    def test_column_property_desc(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id)
+
+        self._fixture({
+            'summation':column_property(cp)
+        })
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(A.summation.desc()).
+                            limit(50),
+            "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id "
+            "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
+            "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) DESC "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
+            "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2 DESC"
+        )
+
+    def test_column_property_correlated(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id).\
+                        correlate(a_table)
+
+        self._fixture({
+            'summation':column_property(cp)
+        })
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(A.summation).
+                            limit(50),
+            "SELECT anon_1.anon_2 AS anon_1_anon_2, anon_1.a_id "
+            "AS anon_1_a_id, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT "
+            "(SELECT sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "AS anon_2, a.id AS a_id FROM a ORDER BY (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 ON "
+            "anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
+        )
+
+    def test_standalone_subquery_unlabeled(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        self._fixture({})
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id).\
+                        correlate(a_table).as_scalar()
+        # note its re-rendering the subquery in the
+        # outermost order by.  usually we want it to address
+        # the column within the subquery.  labelling fixes that.
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(cp).
+                            limit(50),
+            "SELECT anon_1.a_id AS anon_1_a_id, anon_1.anon_2 "
+            "AS anon_1_anon_2, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id "
+            "AS a_id, (SELECT sum(b.value) AS sum_1 FROM b WHERE "
+            "b.a_id = a.id) AS anon_2 FROM a ORDER BY (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
+            "ON anon_1.a_id = b_1.a_id ORDER BY "
+            "(SELECT anon_1.anon_2 FROM b WHERE b.a_id = anon_1.a_id)"
+        )
+
+    def test_standalone_subquery_labeled(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        self._fixture({})
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id).\
+                        correlate(a_table).as_scalar().label('foo')
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(cp).
+                            limit(50),
+            "SELECT anon_1.a_id AS anon_1_a_id, anon_1.foo "
+            "AS anon_1_foo, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id "
+            "AS a_id, (SELECT sum(b.value) AS sum_1 FROM b WHERE "
+            "b.a_id = a.id) AS foo FROM a ORDER BY (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
+            "ON anon_1.a_id = b_1.a_id ORDER BY "
+            "anon_1.foo"
+        )
+
+    def test_standalone_negated(self):
+        A, B = self.classes.A, self.classes.B
+        b_table, a_table = self.tables.b, self.tables.a
+        self._fixture({})
+        cp = select([func.sum(b_table.c.value)]).\
+                        where(b_table.c.a_id==a_table.c.id).\
+                        correlate(a_table).\
+                        as_scalar()
+        # test a different unary operator
+        self.assert_compile(
+            create_session().query(A).options(joinedload_all('bs')).
+                            order_by(~cp).
+                            limit(50),
+            "SELECT anon_1.a_id AS anon_1_a_id, anon_1.anon_2 "
+            "AS anon_1_anon_2, b_1.id AS b_1_id, b_1.a_id AS "
+            "b_1_a_id, b_1.value AS b_1_value FROM (SELECT a.id "
+            "AS a_id, NOT (SELECT sum(b.value) AS sum_1 FROM b "
+            "WHERE b.a_id = a.id) FROM a ORDER BY NOT (SELECT "
+            "sum(b.value) AS sum_1 FROM b WHERE b.a_id = a.id) "
+            "LIMIT :param_1) AS anon_1 LEFT OUTER JOIN b AS b_1 "
+            "ON anon_1.a_id = b_1.a_id ORDER BY anon_1.anon_2"
+        )
+
+
+
+
 class AddEntityTest(_fixtures.FixtureTest):
     run_inserts = 'once'
     run_deletes = None
index debdd0bb7d0d6912b041e63d647dfa6f925b7159..63be50a974e4fb5bc816b19ce6106902fa6a6c63 100644 (file)
@@ -64,7 +64,16 @@ class SelectableTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiled
 
         assert s1.corresponding_column(scalar_select) is s1.c.foo
         assert s2.corresponding_column(scalar_select) is s2.c.foo
-
+    
+    def test_label_grouped_still_corresponds(self):
+        label = select([table1.c.col1]).label('foo')
+        label2 = label.self_group()
+        
+        s1 = select([label])
+        s2 = select([label2])
+        assert s1.corresponding_column(label) is s1.c.foo
+        assert s2.corresponding_column(label) is s2.c.foo
+        
     def test_direct_correspondence_on_labels(self):
         # this test depends on labels being part
         # of the proxy set to get the right result