]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- since correlation is now always at least semi-automatic, remove the
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Mar 2013 16:46:44 +0000 (11:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 9 Mar 2013 16:46:44 +0000 (11:46 -0500)
ability for correlation to have any effect for a SELECT that's stated
in the FROM.
- add a new exhaustive test suite for correlation to test_compiler

lib/sqlalchemy/sql/expression.py
test/sql/test_compiler.py

index 0ebcc1146557db0c11829d44ac6086615e7e5700..41eaace7107a14a4b7d7e02d04f0a3888a1776fd 100644 (file)
@@ -5258,34 +5258,36 @@ class Select(HasPrefixes, SelectBase):
             # using a list to maintain ordering
             froms = [f for f in froms if f not in toremove]
 
-        if self._correlate:
-            froms = [
-                f for f in froms if f not in
-                _cloned_intersection(
-                    _cloned_intersection(froms, existing_froms or ()),
-                    self._correlate
-                )
-            ]
-        if self._correlate_except:
-            froms = [
-                f for f in froms if f in
-                _cloned_intersection(
-                    froms,
-                    self._correlate_except
-                )
-            ]
-        if self._auto_correlate and existing_froms and len(froms) > 1 and not asfrom:
-            froms = [
-                f for f in froms if f not in
-                _cloned_intersection(froms, existing_froms)
-            ]
-
-            if not len(froms):
-                raise exc.InvalidRequestError("Select statement '%s"
-                        "' returned no FROM clauses due to "
-                        "auto-correlation; specify "
-                        "correlate(<tables>) to control "
-                        "correlation manually." % self)
+        if not asfrom:
+            if self._correlate:
+                froms = [
+                    f for f in froms if f not in
+                    _cloned_intersection(
+                        _cloned_intersection(froms, existing_froms or ()),
+                        self._correlate
+                    )
+                ]
+            if self._correlate_except:
+                froms = [
+                    f for f in froms if f in
+                    _cloned_intersection(
+                        froms,
+                        self._correlate_except
+                    )
+                ]
+
+            if self._auto_correlate and existing_froms and len(froms) > 1:
+                froms = [
+                    f for f in froms if f not in
+                    _cloned_intersection(froms, existing_froms)
+                ]
+
+                if not len(froms):
+                    raise exc.InvalidRequestError("Select statement '%s"
+                            "' returned no FROM clauses due to "
+                            "auto-correlation; specify "
+                            "correlate(<tables>) to control "
+                            "correlation manually." % self)
 
         return froms
 
index 22fecf6658e58b43495dd9f7a4237f6919104fee..fe52402eca19d88486afbee08ee96c3db451368a 100644 (file)
@@ -87,6 +87,7 @@ keyed = Table('keyed', metadata,
     Column('z', Integer),
 )
 
+
 class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'
 
@@ -424,39 +425,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
                     "AS z FROM keyed) AS anon_2) AS anon_1"
                     )
 
-    def test_dont_overcorrelate(self):
-        self.assert_compile(select([table1], from_obj=[table1,
-                            table1.select()]),
-                            "SELECT mytable.myid, mytable.name, "
-                            "mytable.description FROM mytable, (SELECT "
-                            "mytable.myid AS myid, mytable.name AS "
-                            "name, mytable.description AS description "
-                            "FROM mytable)")
-
-    def test_full_correlate(self):
-        # intentional
-        t = table('t', column('a'), column('b'))
-        s = select([t.c.a]).where(t.c.a == 1).correlate(t).as_scalar()
-
-        s2 = select([t.c.a, s])
-        self.assert_compile(s2,
-                "SELECT t.a, (SELECT t.a WHERE t.a = :a_1) AS anon_1 FROM t")
-
-        # unintentional
-        t2 = table('t2', column('c'), column('d'))
-        s = select([t.c.a]).where(t.c.a == t2.c.d).as_scalar()
-        s2 = select([t, t2, s])
-        assert_raises(exc.InvalidRequestError, str, s2)
-
-        # intentional again
-        s = s.correlate(t, t2)
-        s2 = select([t, t2, s])
-        self.assert_compile(
-            s2,
-            "SELECT t.a, t.b, t2.c, t2.d, a "
-            "FROM t, t2, (SELECT t.a AS a WHERE t.a = t2.d)"
-        )
-
     def test_exists(self):
         s = select([table1.c.myid]).where(table1.c.myid == 5)
 
@@ -3193,6 +3161,246 @@ class SchemaTest(fixtures.TestBase, AssertsCompiledSQL):
                     "(:rem_id, :datatype_id, :value)")
 
 
+class CorrelateTest(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    def test_dont_overcorrelate(self):
+        self.assert_compile(select([table1], from_obj=[table1,
+                            table1.select()]),
+                            "SELECT mytable.myid, mytable.name, "
+                            "mytable.description FROM mytable, (SELECT "
+                            "mytable.myid AS myid, mytable.name AS "
+                            "name, mytable.description AS description "
+                            "FROM mytable)")
+
+    def _fixture(self):
+        t1 = table('t1', column('a'))
+        t2 = table('t2', column('a'))
+        return t1, t2, select([t1]).where(t1.c.a == t2.c.a)
+
+    def _assert_where_correlated(self, stmt):
+        self.assert_compile(
+                stmt,
+                "SELECT t2.a FROM t2 WHERE t2.a = "
+                "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)")
+
+    def _assert_where_all_correlated(self, stmt):
+        self.assert_compile(
+                stmt,
+                "SELECT t1.a, t2.a FROM t1, t2 WHERE t2.a = "
+                "(SELECT t1.a WHERE t1.a = t2.a)")
+
+    def _assert_where_backwards_correlated(self, stmt):
+        self.assert_compile(
+                stmt,
+                "SELECT t2.a FROM t2 WHERE t2.a = "
+                "(SELECT t1.a FROM t2 WHERE t1.a = t2.a)")
+
+    def _assert_column_correlated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a, (SELECT t1.a FROM t1 WHERE t1.a = t2.a) "
+                "AS anon_1 FROM t2")
+
+    def _assert_column_all_correlated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t1.a, t2.a, "
+                "(SELECT t1.a WHERE t1.a = t2.a) AS anon_1 FROM t1, t2")
+
+    def _assert_column_backwards_correlated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a, (SELECT t1.a FROM t2 WHERE t1.a = t2.a) "
+                "AS anon_1 FROM t2")
+
+    def _assert_having_correlated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a FROM t2 HAVING t2.a = "
+                "(SELECT t1.a FROM t1 WHERE t1.a = t2.a)")
+
+    def _assert_from_uncorrelated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a, anon_1.a FROM t2, "
+                "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1")
+
+    def _assert_from_all_uncorrelated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t1.a, t2.a, anon_1.a FROM t1, t2, "
+                "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a) AS anon_1")
+
+    def _assert_where_uncorrelated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a FROM t2 WHERE t2.a = "
+                "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)")
+
+    def _assert_column_uncorrelated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a, (SELECT t1.a FROM t1, t2 "
+                    "WHERE t1.a = t2.a) AS anon_1 FROM t2")
+
+    def _assert_having_uncorrelated(self, stmt):
+        self.assert_compile(stmt,
+                "SELECT t2.a FROM t2 HAVING t2.a = "
+                "(SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a)")
+
+    def _assert_where_single_full_correlated(self, stmt):
+        self.assert_compile(stmt,
+            "SELECT t1.a FROM t1 WHERE t1.a = (SELECT t1.a)")
+
+    def test_correlate_semiauto_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_correlated(
+                select([t2]).where(t2.c.a == s1.correlate(t2)))
+
+    def test_correlate_semiauto_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_correlated(
+                select([t2, s1.correlate(t2).as_scalar()]))
+
+    def test_correlate_semiauto_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_uncorrelated(
+                select([t2, s1.correlate(t2).alias()]))
+
+    def test_correlate_semiauto_having(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_having_correlated(
+                select([t2]).having(t2.c.a == s1.correlate(t2)))
+
+    def test_correlate_except_inclusion_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_correlated(
+                select([t2]).where(t2.c.a == s1.correlate_except(t1)))
+
+    def test_correlate_except_exclusion_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_backwards_correlated(
+                select([t2]).where(t2.c.a == s1.correlate_except(t2)))
+
+    def test_correlate_except_inclusion_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_correlated(
+                select([t2, s1.correlate_except(t1).as_scalar()]))
+
+    def test_correlate_except_exclusion_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_backwards_correlated(
+                select([t2, s1.correlate_except(t2).as_scalar()]))
+
+    def test_correlate_except_inclusion_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_uncorrelated(
+                select([t2, s1.correlate_except(t1).alias()]))
+
+    def test_correlate_except_exclusion_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_uncorrelated(
+                select([t2, s1.correlate_except(t2).alias()]))
+
+    def test_correlate_except_having(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_having_correlated(
+                select([t2]).having(t2.c.a == s1.correlate_except(t1)))
+
+    def test_correlate_auto_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_correlated(
+                select([t2]).where(t2.c.a == s1))
+
+    def test_correlate_auto_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_correlated(
+                select([t2, s1.as_scalar()]))
+
+    def test_correlate_auto_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_uncorrelated(
+                select([t2, s1.alias()]))
+
+    def test_correlate_auto_having(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_having_correlated(
+                select([t2]).having(t2.c.a == s1))
+
+    def test_correlate_disabled_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_uncorrelated(
+                select([t2]).where(t2.c.a == s1.correlate(None)))
+
+    def test_correlate_disabled_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_uncorrelated(
+                select([t2, s1.correlate(None).as_scalar()]))
+
+    def test_correlate_disabled_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_uncorrelated(
+                select([t2, s1.correlate(None).alias()]))
+
+    def test_correlate_disabled_having(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_having_uncorrelated(
+                select([t2]).having(t2.c.a == s1.correlate(None)))
+
+    def test_correlate_all_where(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_where_all_correlated(
+                select([t1, t2]).where(t2.c.a == s1.correlate(t1, t2)))
+
+    def test_correlate_all_column(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_column_all_correlated(
+                select([t1, t2, s1.correlate(t1, t2).as_scalar()]))
+
+    def test_correlate_all_from(self):
+        t1, t2, s1 = self._fixture()
+        self._assert_from_all_uncorrelated(
+                select([t1, t2, s1.correlate(t1, t2).alias()]))
+
+    def test_correlate_where_all_unintentional(self):
+        t1, t2, s1 = self._fixture()
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "returned no FROM clauses due to auto-correlation",
+            select([t1, t2]).where(t2.c.a == s1).compile
+        )
+
+    def test_correlate_from_all_ok(self):
+        t1, t2, s1 = self._fixture()
+        self.assert_compile(
+            select([t1, t2, s1]),
+            "SELECT t1.a, t2.a, a FROM t1, t2, "
+            "(SELECT t1.a AS a FROM t1, t2 WHERE t1.a = t2.a)"
+        )
+
+    def test_correlate_auto_where_singlefrom(self):
+        t1, t2, s1 = self._fixture()
+        s = select([t1.c.a])
+        s2 = select([t1]).where(t1.c.a == s)
+        self.assert_compile(s2,
+                "SELECT t1.a FROM t1 WHERE t1.a = "
+                "(SELECT t1.a FROM t1)")
+
+    def test_correlate_semiauto_where_singlefrom(self):
+        t1, t2, s1 = self._fixture()
+
+        s = select([t1.c.a])
+
+        s2 = select([t1]).where(t1.c.a == s.correlate(t1))
+        self._assert_where_single_full_correlated(s2)
+
+    def test_correlate_except_semiauto_where_singlefrom(self):
+        t1, t2, s1 = self._fixture()
+
+        s = select([t1.c.a])
+
+        s2 = select([t1]).where(t1.c.a == s.correlate_except(t2))
+        self._assert_where_single_full_correlated(s2)
+
+    def test_correlate_alone_noeffect(self):
+        # new as of #2668
+        t1, t2, s1 = self._fixture()
+        self.assert_compile(s1.correlate(t1, t2),
+            "SELECT t1.a FROM t1, t2 WHERE t1.a = t2.a")
+
 class CoercionTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = 'default'