]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Changed behavior of Select.correlate() to ignore correlations to froms that don't...
authorLuke Cyca <me@lukecyca.com>
Thu, 7 Mar 2013 19:56:11 +0000 (11:56 -0800)
committerLuke Cyca <me@lukecyca.com>
Thu, 7 Mar 2013 19:56:11 +0000 (11:56 -0800)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/orm/test_query.py
test/sql/test_compiler.py
test/sql/test_generative.py

index 59e46de122d1574f5c6c5eb58aa16de57d27b97f..90e9067277eb9e44823953ba51e19dacadef1e08 100644 (file)
@@ -1086,14 +1086,9 @@ class SQLCompiler(engine.Compiled):
                             positional_names=None, **kwargs):
         entry = self.stack and self.stack[-1] or {}
 
-        if not asfrom:
-            existingfroms = entry.get('from', None)
-        else:
-            # don't render correlations if we're rendering a FROM list
-            # entry
-            existingfroms = []
+        existingfroms = entry.get('from', None)
 
-        froms = select._get_display_froms(existingfroms)
+        froms = select._get_display_froms(existingfroms, asfrom=asfrom)
 
         correlate_froms = set(sql._from_objects(*froms))
 
index 490004e39002eefde799280d555cd5a887747add..0ebcc1146557db0c11829d44ac6086615e7e5700 100644 (file)
@@ -4980,7 +4980,7 @@ class CompoundSelect(SelectBase):
     INTERSECT_ALL = util.symbol('INTERSECT ALL')
 
     def __init__(self, keyword, *selects, **kwargs):
-        self._should_correlate = kwargs.pop('correlate', False)
+        self._auto_correlate = kwargs.pop('correlate', False)
         self.keyword = keyword
         self.selects = []
 
@@ -5159,7 +5159,7 @@ class Select(HasPrefixes, SelectBase):
         :class:`SelectBase` superclass.
 
         """
-        self._should_correlate = correlate
+        self._auto_correlate = correlate
         if distinct is not False:
             if distinct is True:
                 self._distinct = True
@@ -5232,7 +5232,7 @@ class Select(HasPrefixes, SelectBase):
 
         return froms
 
-    def _get_display_froms(self, existing_froms=None):
+    def _get_display_froms(self, existing_froms=None, asfrom=False):
         """Return the full list of 'from' clauses to be displayed.
 
         Takes into account a set of existing froms which may be
@@ -5258,25 +5258,34 @@ class Select(HasPrefixes, SelectBase):
             # using a list to maintain ordering
             froms = [f for f in froms if f not in toremove]
 
-        if len(froms) > 1 or self._correlate or self._correlate_except:
-            if self._correlate:
-                froms = [f for f in froms if f not in
-                        _cloned_intersection(froms,
-                        self._correlate)]
-            if self._correlate_except:
-                froms = [f for f in froms if f in _cloned_intersection(froms,
-                        self._correlate_except)]
-            if self._should_correlate and existing_froms:
-                froms = [f for f in froms if f not in
-                        _cloned_intersection(froms,
-                        existing_froms)]
-
-                if not len(froms):
-                    raise exc.InvalidRequestError("Select statement '%s"
-                            "' returned no FROM clauses due to "
-                            "auto-correlation; specify "
-                            "correlate(<tables>) to control "
-                            "correlation manually." % self)
+        if self._correlate:
+            froms = [
+                f for f in froms if f not in
+                _cloned_intersection(
+                    _cloned_intersection(froms, existing_froms or ()),
+                    self._correlate
+                )
+            ]
+        if self._correlate_except:
+            froms = [
+                f for f in froms if f in
+                _cloned_intersection(
+                    froms,
+                    self._correlate_except
+                )
+            ]
+        if self._auto_correlate and existing_froms and len(froms) > 1 and not asfrom:
+            froms = [
+                f for f in froms if f not in
+                _cloned_intersection(froms, existing_froms)
+            ]
+
+            if not len(froms):
+                raise exc.InvalidRequestError("Select statement '%s"
+                        "' returned no FROM clauses due to "
+                        "auto-correlation; specify "
+                        "correlate(<tables>) to control "
+                        "correlation manually." % self)
 
         return froms
 
@@ -5642,7 +5651,7 @@ class Select(HasPrefixes, SelectBase):
             :ref:`correlated_subqueries`
 
         """
-        self._should_correlate = False
+        self._auto_correlate = False
         if fromclauses and fromclauses[0] is None:
             self._correlate = ()
         else:
@@ -5662,7 +5671,7 @@ class Select(HasPrefixes, SelectBase):
             :ref:`correlated_subqueries`
 
         """
-        self._should_correlate = False
+        self._auto_correlate = False
         if fromclauses and fromclauses[0] is None:
             self._correlate_except = ()
         else:
@@ -5673,7 +5682,7 @@ class Select(HasPrefixes, SelectBase):
         """append the given correlation expression to this select()
         construct."""
 
-        self._should_correlate = False
+        self._auto_correlate = False
         self._correlate = set(self._correlate).union(
                 _interpret_as_from(f) for f in fromclause)
 
index f418d2581e90fabf269e789c2b6d4f7fedf5b8f3..be5d2b135048c408911614b40e96e351c3f948ee 100644 (file)
@@ -194,22 +194,28 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
         Address = self.classes.Address
 
         self.assert_compile(
-            select([User]).where(User.id == Address.user_id).
-                correlate(Address),
-            "SELECT users.id, users.name FROM users "
-            "WHERE users.id = addresses.user_id"
+            select([
+                User.name,
+                select([func.count(Address.id)
+            ]).where(User.id == Address.user_id).correlate(User)]),
+            "SELECT users.name, count_1 FROM users, "
+            "(SELECT count(addresses.id) AS count_1 "
+            "FROM addresses WHERE users.id = addresses.user_id)"
         )
 
     def test_correlate_aliased_entity(self):
         User = self.classes.User
         Address = self.classes.Address
-        aa = aliased(Address, name="aa")
+        uu = aliased(User, name="uu")
 
         self.assert_compile(
-            select([User]).where(User.id == aa.user_id).
-                correlate(aa),
-            "SELECT users.id, users.name FROM users "
-            "WHERE users.id = aa.user_id"
+            select([
+                uu.name,
+                select([func.count(Address.id)
+            ]).where(uu.id == Address.user_id).correlate(uu)]),
+            "SELECT uu.name, count_1 FROM users AS uu, "
+            "(SELECT count(addresses.id) AS count_1 "
+            "FROM addresses WHERE addresses.user_id = uu.id)"
         )
 
     def test_columns_clause_entity(self):
index 3b8aed23f6aca045853d60618a2e6c4dea196f7f..22fecf6658e58b43495dd9f7a4237f6919104fee 100644 (file)
@@ -451,7 +451,11 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         # intentional again
         s = s.correlate(t, t2)
         s2 = select([t, t2, s])
-        self.assert_compile(s, "SELECT t.a WHERE t.a = t2.d")
+        self.assert_compile(
+            s2,
+            "SELECT t.a, t.b, t2.c, t2.d, a "
+            "FROM t, t2, (SELECT t.a AS a WHERE t.a = t2.d)"
+        )
 
     def test_exists(self):
         s = select([table1.c.myid]).where(table1.c.myid == 5)
@@ -3315,4 +3319,4 @@ class ResultMapTest(fixtures.TestBase):
         )
         is_(
             comp.result_map['t1_a'][1][2], t1.c.a
-        )
\ No newline at end of file
+        )
index e868cbe885c4eb551a3df73034afa8c5470c8ba4..b43761f6ff2e0fd0f134778aed394f1242aab06f 100644 (file)
@@ -588,15 +588,24 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         assert orig == str(s) == str(s5)
 
     def test_correlated_select(self):
-        s = select(['*'], t1.c.col1 == t2.c.col1,
-                    from_obj=[t1, t2]).correlate(t2)
+        s = select(
+            [func.count(t1.c.col1)],
+            t1.c.col1 == t2.c.col1,
+            from_obj=[t1, t2]
+        ).correlate(t2)
+
         class Vis(CloningVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col2 == 7)
 
-        self.assert_compile(Vis().traverse(s),
-                    "SELECT * FROM table1 WHERE table1.col1 = table2.col1 "
-                    "AND table1.col2 = :col2_1")
+        supers = select([t2, Vis().traverse(s)])
+
+        self.assert_compile(supers,
+                    "SELECT table2.col1, table2.col2, table2.col3, "
+                    "count_1 FROM table2, "
+                    "(SELECT count(table1.col1) AS count_1 "
+                    "FROM table1 WHERE table1.col1 = table2.col1 "
+                    "AND table1.col2 = :col2_1)")
 
     def test_this_thing(self):
         s = select([t1]).where(t1.c.col1 == 'foo').alias()
@@ -619,32 +628,32 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_select_fromtwice(self):
         t1a = t1.alias()
 
-        s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1)
+        s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a)
         self.assert_compile(s,
-                            'SELECT 1 FROM table1 AS table1_1 WHERE '
+                            'SELECT 1 FROM table1, table1 AS table1_1 WHERE '
                             'table1.col1 = table1_1.col1')
 
         s = CloningVisitor().traverse(s)
         self.assert_compile(s,
-                            'SELECT 1 FROM table1 AS table1_1 WHERE '
+                            'SELECT 1 FROM table1, table1 AS table1_1 WHERE '
                             'table1.col1 = table1_1.col1')
 
-        s = select([t1]).where(t1.c.col1 == 'foo').alias()
+        s = select([t1]).where(t1.c.col1 == 'foo').correlate(t1).alias()
 
-        s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1)
+        s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s)
         self.assert_compile(s2,
-                            'SELECT 1 FROM (SELECT table1.col1 AS '
-                            'col1, table1.col2 AS col2, table1.col3 AS '
-                            'col3 FROM table1 WHERE table1.col1 = '
-                            ':col1_1) AS anon_1 WHERE table1.col1 = '
-                            'anon_1.col1')
+                            'SELECT 1 FROM table1, '
+                            '(SELECT table1.col1 AS col1, '
+                            'table1.col2 AS col2, table1.col3 AS col3 '
+                            'WHERE table1.col1 = :col1_1) AS anon_1 '
+                            'WHERE table1.col1 = anon_1.col1')
         s2 = ReplacingCloningVisitor().traverse(s2)
         self.assert_compile(s2,
-                            'SELECT 1 FROM (SELECT table1.col1 AS '
-                            'col1, table1.col2 AS col2, table1.col3 AS '
-                            'col3 FROM table1 WHERE table1.col1 = '
-                            ':col1_1) AS anon_1 WHERE table1.col1 = '
-                            'anon_1.col1')
+                            'SELECT 1 FROM table1, '
+                            '(SELECT table1.col1 AS col1, '
+                            'table1.col2 AS col2, table1.col3 AS col3 '
+                            'WHERE table1.col1 = :col1_1) AS anon_1 '
+                            'WHERE table1.col1 = anon_1.col1')
 
 class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
@@ -784,16 +793,15 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
                             == t2.c.col2, from_obj=[t1, t2])),
                             'SELECT * FROM table1 AS t1alias, table2 '
                             'WHERE t1alias.col1 = table2.col2')
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1
-                            == t2.c.col2, from_obj=[t1,
-                            t2]).correlate(t1)),
-                            'SELECT * FROM table2 WHERE t1alias.col1 = '
-                            'table2.col2')
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1
-                            == t2.c.col2, from_obj=[t1,
-                            t2]).correlate(t2)),
-                            'SELECT * FROM table1 AS t1alias WHERE '
-                            't1alias.col1 = table2.col2')
+        self.assert_compile(vis.traverse(select([
+                            t1, select([func.count(t2.c.col2)],
+                            t1.c.col1 == t2.c.col2, from_obj=[t1,
+                            t2]).correlate(t1)])),
+                            'SELECT t1alias.col1, t1alias.col2, '
+                            't1alias.col3, count_1 FROM table1 AS '
+                            't1alias, (SELECT count(table2.col2) '
+                            'AS count_1 FROM table2 '
+                            'WHERE t1alias.col1 = table2.col2)')
         self.assert_compile(vis.traverse(case([(t1.c.col1 == 5,
                             t1.c.col2)], else_=t1.c.col1)),
                             'CASE WHEN (t1alias.col1 = :col1_1) THEN '
@@ -836,16 +844,15 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL):
                             'SELECT * FROM table1 AS t1alias, table2 '
                             'AS t2alias WHERE t1alias.col1 = '
                             't2alias.col2')
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1
-                            == t2.c.col2, from_obj=[t1,
-                            t2]).correlate(t1)),
-                            'SELECT * FROM table2 AS t2alias WHERE '
-                            't1alias.col1 = t2alias.col2')
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1
-                            == t2.c.col2, from_obj=[t1,
-                            t2]).correlate(t2)),
-                            'SELECT * FROM table1 AS t1alias WHERE '
-                            't1alias.col1 = t2alias.col2')
+        self.assert_compile(vis.traverse(select([
+                            t1, select([func.count(t2.c.col2)],
+                            t1.c.col1 == t2.c.col2, from_obj=[t1,
+                            t2]).correlate(t1)])),
+                            'SELECT t1alias.col1, t1alias.col2, '
+                            't1alias.col3, count_1 FROM table1 AS '
+                            't1alias, (SELECT count(t2alias.col2) '
+                            'AS count_1 FROM table2 AS t2alias '
+                            'WHERE t1alias.col1 = t2alias.col2)')
 
     def test_include_exclude(self):
         m = MetaData()