From: Luke Cyca Date: Thu, 7 Mar 2013 19:56:11 +0000 (-0800) Subject: Changed behavior of Select.correlate() to ignore correlations to froms that don't... X-Git-Tag: rel_0_8_0~7^2~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f122a307e03d9a1f2322b35429972a5f928d5b30;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Changed behavior of Select.correlate() to ignore correlations to froms that don't exist in the superquery. --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 59e46de122..90e9067277 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1086,14 +1086,9 @@ class SQLCompiler(engine.Compiled): positional_names=None, **kwargs): entry = self.stack and self.stack[-1] or {} - if not asfrom: - existingfroms = entry.get('from', None) - else: - # don't render correlations if we're rendering a FROM list - # entry - existingfroms = [] + existingfroms = entry.get('from', None) - froms = select._get_display_froms(existingfroms) + froms = select._get_display_froms(existingfroms, asfrom=asfrom) correlate_froms = set(sql._from_objects(*froms)) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 490004e390..0ebcc11465 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -4980,7 +4980,7 @@ class CompoundSelect(SelectBase): INTERSECT_ALL = util.symbol('INTERSECT ALL') def __init__(self, keyword, *selects, **kwargs): - self._should_correlate = kwargs.pop('correlate', False) + self._auto_correlate = kwargs.pop('correlate', False) self.keyword = keyword self.selects = [] @@ -5159,7 +5159,7 @@ class Select(HasPrefixes, SelectBase): :class:`SelectBase` superclass. """ - self._should_correlate = correlate + self._auto_correlate = correlate if distinct is not False: if distinct is True: self._distinct = True @@ -5232,7 +5232,7 @@ class Select(HasPrefixes, SelectBase): return froms - def _get_display_froms(self, existing_froms=None): + def _get_display_froms(self, existing_froms=None, asfrom=False): """Return the full list of 'from' clauses to be displayed. Takes into account a set of existing froms which may be @@ -5258,25 +5258,34 @@ class Select(HasPrefixes, SelectBase): # using a list to maintain ordering froms = [f for f in froms if f not in toremove] - if len(froms) > 1 or self._correlate or self._correlate_except: - if self._correlate: - froms = [f for f in froms if f not in - _cloned_intersection(froms, - self._correlate)] - if self._correlate_except: - froms = [f for f in froms if f in _cloned_intersection(froms, - self._correlate_except)] - if self._should_correlate and existing_froms: - froms = [f for f in froms if f not in - _cloned_intersection(froms, - existing_froms)] - - if not len(froms): - raise exc.InvalidRequestError("Select statement '%s" - "' returned no FROM clauses due to " - "auto-correlation; specify " - "correlate() to control " - "correlation manually." % self) + if self._correlate: + froms = [ + f for f in froms if f not in + _cloned_intersection( + _cloned_intersection(froms, existing_froms or ()), + self._correlate + ) + ] + if self._correlate_except: + froms = [ + f for f in froms if f in + _cloned_intersection( + froms, + self._correlate_except + ) + ] + if self._auto_correlate and existing_froms and len(froms) > 1 and not asfrom: + froms = [ + f for f in froms if f not in + _cloned_intersection(froms, existing_froms) + ] + + if not len(froms): + raise exc.InvalidRequestError("Select statement '%s" + "' returned no FROM clauses due to " + "auto-correlation; specify " + "correlate() to control " + "correlation manually." % self) return froms @@ -5642,7 +5651,7 @@ class Select(HasPrefixes, SelectBase): :ref:`correlated_subqueries` """ - self._should_correlate = False + self._auto_correlate = False if fromclauses and fromclauses[0] is None: self._correlate = () else: @@ -5662,7 +5671,7 @@ class Select(HasPrefixes, SelectBase): :ref:`correlated_subqueries` """ - self._should_correlate = False + self._auto_correlate = False if fromclauses and fromclauses[0] is None: self._correlate_except = () else: @@ -5673,7 +5682,7 @@ class Select(HasPrefixes, SelectBase): """append the given correlation expression to this select() construct.""" - self._should_correlate = False + self._auto_correlate = False self._correlate = set(self._correlate).union( _interpret_as_from(f) for f in fromclause) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index f418d2581e..be5d2b1350 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -194,22 +194,28 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): Address = self.classes.Address self.assert_compile( - select([User]).where(User.id == Address.user_id). - correlate(Address), - "SELECT users.id, users.name FROM users " - "WHERE users.id = addresses.user_id" + select([ + User.name, + select([func.count(Address.id) + ]).where(User.id == Address.user_id).correlate(User)]), + "SELECT users.name, count_1 FROM users, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE users.id = addresses.user_id)" ) def test_correlate_aliased_entity(self): User = self.classes.User Address = self.classes.Address - aa = aliased(Address, name="aa") + uu = aliased(User, name="uu") self.assert_compile( - select([User]).where(User.id == aa.user_id). - correlate(aa), - "SELECT users.id, users.name FROM users " - "WHERE users.id = aa.user_id" + select([ + uu.name, + select([func.count(Address.id) + ]).where(uu.id == Address.user_id).correlate(uu)]), + "SELECT uu.name, count_1 FROM users AS uu, " + "(SELECT count(addresses.id) AS count_1 " + "FROM addresses WHERE addresses.user_id = uu.id)" ) def test_columns_clause_entity(self): diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 3b8aed23f6..22fecf6658 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -451,7 +451,11 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): # intentional again s = s.correlate(t, t2) s2 = select([t, t2, s]) - self.assert_compile(s, "SELECT t.a WHERE t.a = t2.d") + self.assert_compile( + s2, + "SELECT t.a, t.b, t2.c, t2.d, a " + "FROM t, t2, (SELECT t.a AS a WHERE t.a = t2.d)" + ) def test_exists(self): s = select([table1.c.myid]).where(table1.c.myid == 5) @@ -3315,4 +3319,4 @@ class ResultMapTest(fixtures.TestBase): ) is_( comp.result_map['t1_a'][1][2], t1.c.a - ) \ No newline at end of file + ) diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index e868cbe885..b43761f6ff 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -588,15 +588,24 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): assert orig == str(s) == str(s5) def test_correlated_select(self): - s = select(['*'], t1.c.col1 == t2.c.col1, - from_obj=[t1, t2]).correlate(t2) + s = select( + [func.count(t1.c.col1)], + t1.c.col1 == t2.c.col1, + from_obj=[t1, t2] + ).correlate(t2) + class Vis(CloningVisitor): def visit_select(self, select): select.append_whereclause(t1.c.col2 == 7) - self.assert_compile(Vis().traverse(s), - "SELECT * FROM table1 WHERE table1.col1 = table2.col1 " - "AND table1.col2 = :col2_1") + supers = select([t2, Vis().traverse(s)]) + + self.assert_compile(supers, + "SELECT table2.col1, table2.col2, table2.col3, " + "count_1 FROM table2, " + "(SELECT count(table1.col1) AS count_1 " + "FROM table1 WHERE table1.col1 = table2.col1 " + "AND table1.col2 = :col2_1)") def test_this_thing(self): s = select([t1]).where(t1.c.col1 == 'foo').alias() @@ -619,32 +628,32 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): def test_select_fromtwice(self): t1a = t1.alias() - s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a).correlate(t1) + s = select([1], t1.c.col1 == t1a.c.col1, from_obj=t1a) self.assert_compile(s, - 'SELECT 1 FROM table1 AS table1_1 WHERE ' + 'SELECT 1 FROM table1, table1 AS table1_1 WHERE ' 'table1.col1 = table1_1.col1') s = CloningVisitor().traverse(s) self.assert_compile(s, - 'SELECT 1 FROM table1 AS table1_1 WHERE ' + 'SELECT 1 FROM table1, table1 AS table1_1 WHERE ' 'table1.col1 = table1_1.col1') - s = select([t1]).where(t1.c.col1 == 'foo').alias() + s = select([t1]).where(t1.c.col1 == 'foo').correlate(t1).alias() - s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s).correlate(t1) + s2 = select([1], t1.c.col1 == s.c.col1, from_obj=s) self.assert_compile(s2, - 'SELECT 1 FROM (SELECT table1.col1 AS ' - 'col1, table1.col2 AS col2, table1.col3 AS ' - 'col3 FROM table1 WHERE table1.col1 = ' - ':col1_1) AS anon_1 WHERE table1.col1 = ' - 'anon_1.col1') + 'SELECT 1 FROM table1, ' + '(SELECT table1.col1 AS col1, ' + 'table1.col2 AS col2, table1.col3 AS col3 ' + 'WHERE table1.col1 = :col1_1) AS anon_1 ' + 'WHERE table1.col1 = anon_1.col1') s2 = ReplacingCloningVisitor().traverse(s2) self.assert_compile(s2, - 'SELECT 1 FROM (SELECT table1.col1 AS ' - 'col1, table1.col2 AS col2, table1.col3 AS ' - 'col3 FROM table1 WHERE table1.col1 = ' - ':col1_1) AS anon_1 WHERE table1.col1 = ' - 'anon_1.col1') + 'SELECT 1 FROM table1, ' + '(SELECT table1.col1 AS col1, ' + 'table1.col2 AS col2, table1.col3 AS col3 ' + 'WHERE table1.col1 = :col1_1) AS anon_1 ' + 'WHERE table1.col1 = anon_1.col1') class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = 'default' @@ -784,16 +793,15 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): == t2.c.col2, from_obj=[t1, t2])), 'SELECT * FROM table1 AS t1alias, table2 ' 'WHERE t1alias.col1 = table2.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t1)), - 'SELECT * FROM table2 WHERE t1alias.col1 = ' - 'table2.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t2)), - 'SELECT * FROM table1 AS t1alias WHERE ' - 't1alias.col1 = table2.col2') + self.assert_compile(vis.traverse(select([ + t1, select([func.count(t2.c.col2)], + t1.c.col1 == t2.c.col2, from_obj=[t1, + t2]).correlate(t1)])), + 'SELECT t1alias.col1, t1alias.col2, ' + 't1alias.col3, count_1 FROM table1 AS ' + 't1alias, (SELECT count(table2.col2) ' + 'AS count_1 FROM table2 ' + 'WHERE t1alias.col1 = table2.col2)') self.assert_compile(vis.traverse(case([(t1.c.col1 == 5, t1.c.col2)], else_=t1.c.col1)), 'CASE WHEN (t1alias.col1 = :col1_1) THEN ' @@ -836,16 +844,15 @@ class ClauseAdapterTest(fixtures.TestBase, AssertsCompiledSQL): 'SELECT * FROM table1 AS t1alias, table2 ' 'AS t2alias WHERE t1alias.col1 = ' 't2alias.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t1)), - 'SELECT * FROM table2 AS t2alias WHERE ' - 't1alias.col1 = t2alias.col2') - self.assert_compile(vis.traverse(select(['*'], t1.c.col1 - == t2.c.col2, from_obj=[t1, - t2]).correlate(t2)), - 'SELECT * FROM table1 AS t1alias WHERE ' - 't1alias.col1 = t2alias.col2') + self.assert_compile(vis.traverse(select([ + t1, select([func.count(t2.c.col2)], + t1.c.col1 == t2.c.col2, from_obj=[t1, + t2]).correlate(t1)])), + 'SELECT t1alias.col1, t1alias.col2, ' + 't1alias.col3, count_1 FROM table1 AS ' + 't1alias, (SELECT count(t2alias.col2) ' + 'AS count_1 FROM table2 AS t2alias ' + 'WHERE t1alias.col1 = t2alias.col2)') def test_include_exclude(self): m = MetaData()