From: Mike Bayer Date: Tue, 6 Mar 2018 02:36:18 +0000 (-0500) Subject: Clone _cte_alias instead of assigning "self" X-Git-Tag: rel_1_3_0b1~235^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9a13f007e2342def94cc7362eeadd5ec8c988340;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Clone _cte_alias instead of assigning "self" Fixed bug in :class:.`CTE` construct along the same lines as that of :ticket:`4204` where a :class:`.CTE` that was aliased would not copy itself correctly during a "clone" operation as is frequent within the ORM as well as when using the :meth:`.ClauseElement.params` method. Change-Id: Id68d72dd244dedfc7bd6116c9a5123c51a55ea20 Fixes: #4210 --- diff --git a/doc/build/changelog/unreleased_12/4210.rst b/doc/build/changelog/unreleased_12/4210.rst new file mode 100644 index 0000000000..04e7e86659 --- /dev/null +++ b/doc/build/changelog/unreleased_12/4210.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 4210 + :versions: 1.3.0b1 + + Fixed bug in :class:.`CTE` construct along the same lines as that of + :ticket:`4204` where a :class:`.CTE` that was aliased would not copy itself + correctly during a "clone" operation as is frequent within the ORM as well + as when using the :meth:`.ClauseElement.params` method. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 29b8836dd9..04f6c086d1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1391,7 +1391,7 @@ class CTE(Generative, HasSuffixes, Alias): def _copy_internals(self, clone=_clone, **kw): super(CTE, self)._copy_internals(clone, **kw) if self._cte_alias is not None: - self._cte_alias = self + self._cte_alias = clone(self._cte_alias, **kw) self._restates = frozenset([ clone(elem, **kw) for elem in self._restates ]) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index af9c8ceb6f..2c19ed0324 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -1,10 +1,13 @@ from sqlalchemy.testing import fixtures, eq_ from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message -from sqlalchemy.sql import table, column, select, func, literal, exists, and_ +from sqlalchemy.sql import table, column, select, func, literal, exists, \ + and_, bindparam from sqlalchemy.dialects import mssql from sqlalchemy.engine import default from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.visitors import cloned_traverse + class CTETest(fixtures.TestBase, AssertsCompiledSQL): @@ -436,6 +439,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM regional_sales AS rs WHERE " "rs.amount < :amount_2") + cloned = cloned_traverse(s, {}, {}) + self.assert_compile(cloned, + "WITH regional_sales AS " + "(SELECT orders.region AS region, " + "orders.amount AS amount FROM orders) " + "SELECT rs.region FROM regional_sales AS rs " + "WHERE rs.amount > :amount_1 " + "UNION ALL SELECT rs.region " + "FROM regional_sales AS rs WHERE " + "rs.amount < :amount_2") + + def test_cloned_alias(self): + entity = table( + 'entity', column('id'), column('employer_id'), column('name')) + tag = table('tag', column('tag'), column('entity_id')) + + tags = select([ + tag.c.entity_id, + func.array_agg(tag.c.tag).label('tags'), + ]).group_by(tag.c.entity_id).cte('unaliased_tags') + + entity_tags = tags.alias(name='entity_tags') + employer_tags = tags.alias(name='employer_tags') + + q = ( + select([entity.c.name]) + .select_from( + entity + .outerjoin(entity_tags, tags.c.entity_id == entity.c.id) + .outerjoin(employer_tags, + tags.c.entity_id == entity.c.employer_id) + ) + .where(entity_tags.c.tags.op('@>')(bindparam('tags'))) + .where(employer_tags.c.tags.op('@>')(bindparam('tags'))) + ) + + self.assert_compile( + q, + 'WITH unaliased_tags AS ' + '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags ' + 'FROM tag GROUP BY tag.entity_id)' + ' SELECT entity.name ' + 'FROM entity ' + 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON ' + 'unaliased_tags.entity_id = entity.id ' + 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON ' + 'unaliased_tags.entity_id = entity.employer_id ' + 'WHERE (entity_tags.tags @> :tags) AND ' + '(employer_tags.tags @> :tags)' + ) + + cloned = q.params(tags=['tag1', 'tag2']) + self.assert_compile( + cloned, + 'WITH unaliased_tags AS ' + '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags ' + 'FROM tag GROUP BY tag.entity_id)' + ' SELECT entity.name ' + 'FROM entity ' + 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON ' + 'unaliased_tags.entity_id = entity.id ' + 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON ' + 'unaliased_tags.entity_id = entity.employer_id ' + 'WHERE (entity_tags.tags @> :tags) AND ' + '(employer_tags.tags @> :tags)') + def test_reserved_quote(self): orders = table('orders', column('order'), diff --git a/test/sql/test_generative.py b/test/sql/test_generative.py index 9474560ff1..145b2da3cf 100644 --- a/test/sql/test_generative.py +++ b/test/sql/test_generative.py @@ -497,6 +497,22 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT sum(t.n) AS sum_1 FROM t" ) + def test_aliased_cte_w_union(self): + t = select([func.values(1).label("n")]).\ + cte("t", recursive=True).alias('foo') + t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100)) + s = select([func.sum(t.c.n)]) + + from sqlalchemy.sql.visitors import cloned_traverse + cloned = cloned_traverse(s, {}, {}) + + self.assert_compile( + cloned, + "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n " + "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM t AS foo " + "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo" + ) + def test_text(self): clause = text( "select * from table where foo=:bar",