]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Clone _cte_alias instead of assigning "self"
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Mar 2018 02:36:18 +0000 (21:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Mar 2018 02:37:12 +0000 (21:37 -0500)
Fixed bug in :class:.`CTE` construct along the same lines as that of
:ticket:`4204` where a :class:`.CTE` that was aliased would not copy itself
correctly during a "clone" operation as is frequent within the ORM as well
as when using the :meth:`.ClauseElement.params` method.

Change-Id: Id68d72dd244dedfc7bd6116c9a5123c51a55ea20
Fixes: #4210
doc/build/changelog/unreleased_12/4210.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/sql/test_cte.py
test/sql/test_generative.py

diff --git a/doc/build/changelog/unreleased_12/4210.rst b/doc/build/changelog/unreleased_12/4210.rst
new file mode 100644 (file)
index 0000000..04e7e86
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4210
+    :versions: 1.3.0b1
+
+    Fixed bug in :class:.`CTE` construct along the same lines as that of
+    :ticket:`4204` where a :class:`.CTE` that was aliased would not copy itself
+    correctly during a "clone" operation as is frequent within the ORM as well
+    as when using the :meth:`.ClauseElement.params` method.
index 29b8836dd982bbc349083cf7d2fa8cfffc91cf25..04f6c086d17c907914ac12cd59e00482a77a7f02 100644 (file)
@@ -1391,7 +1391,7 @@ class CTE(Generative, HasSuffixes, Alias):
     def _copy_internals(self, clone=_clone, **kw):
         super(CTE, self)._copy_internals(clone, **kw)
         if self._cte_alias is not None:
-            self._cte_alias = self
+            self._cte_alias = clone(self._cte_alias, **kw)
         self._restates = frozenset([
             clone(elem, **kw) for elem in self._restates
         ])
index af9c8ceb6fba3bccaffd86c12c7f122b110a4f70..2c19ed0324557609bc98b599adc7b33224e9ce02 100644 (file)
@@ -1,10 +1,13 @@
 from sqlalchemy.testing import fixtures, eq_
 from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
-from sqlalchemy.sql import table, column, select, func, literal, exists, and_
+from sqlalchemy.sql import table, column, select, func, literal, exists, \
+    and_, bindparam
 from sqlalchemy.dialects import mssql
 from sqlalchemy.engine import default
 from sqlalchemy.exc import CompileError
 from sqlalchemy.sql.elements import quoted_name
+from sqlalchemy.sql.visitors import cloned_traverse
+
 
 class CTETest(fixtures.TestBase, AssertsCompiledSQL):
 
@@ -436,6 +439,72 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL):
                             "FROM regional_sales AS rs WHERE "
                             "rs.amount < :amount_2")
 
+        cloned = cloned_traverse(s, {}, {})
+        self.assert_compile(cloned,
+                            "WITH regional_sales AS "
+                            "(SELECT orders.region AS region, "
+                            "orders.amount AS amount FROM orders) "
+                            "SELECT rs.region FROM regional_sales AS rs "
+                            "WHERE rs.amount > :amount_1 "
+                            "UNION ALL SELECT rs.region "
+                            "FROM regional_sales AS rs WHERE "
+                            "rs.amount < :amount_2")
+
+    def test_cloned_alias(self):
+        entity = table(
+            'entity', column('id'), column('employer_id'), column('name'))
+        tag = table('tag', column('tag'), column('entity_id'))
+
+        tags = select([
+            tag.c.entity_id,
+            func.array_agg(tag.c.tag).label('tags'),
+        ]).group_by(tag.c.entity_id).cte('unaliased_tags')
+
+        entity_tags = tags.alias(name='entity_tags')
+        employer_tags = tags.alias(name='employer_tags')
+
+        q = (
+            select([entity.c.name])
+            .select_from(
+                entity
+                .outerjoin(entity_tags, tags.c.entity_id == entity.c.id)
+                .outerjoin(employer_tags,
+                           tags.c.entity_id == entity.c.employer_id)
+            )
+            .where(entity_tags.c.tags.op('@>')(bindparam('tags')))
+            .where(employer_tags.c.tags.op('@>')(bindparam('tags')))
+        )
+
+        self.assert_compile(
+            q,
+            'WITH unaliased_tags AS '
+            '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags '
+            'FROM tag GROUP BY tag.entity_id)'
+            ' SELECT entity.name '
+            'FROM entity '
+            'LEFT OUTER JOIN unaliased_tags AS entity_tags ON '
+            'unaliased_tags.entity_id = entity.id '
+            'LEFT OUTER JOIN unaliased_tags AS employer_tags ON '
+            'unaliased_tags.entity_id = entity.employer_id '
+            'WHERE (entity_tags.tags @> :tags) AND '
+            '(employer_tags.tags @> :tags)'
+        )
+
+        cloned = q.params(tags=['tag1', 'tag2'])
+        self.assert_compile(
+            cloned,
+            'WITH unaliased_tags AS '
+            '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags '
+            'FROM tag GROUP BY tag.entity_id)'
+            ' SELECT entity.name '
+            'FROM entity '
+            'LEFT OUTER JOIN unaliased_tags AS entity_tags ON '
+            'unaliased_tags.entity_id = entity.id '
+            'LEFT OUTER JOIN unaliased_tags AS employer_tags ON '
+            'unaliased_tags.entity_id = entity.employer_id '
+            'WHERE (entity_tags.tags @> :tags) AND '
+            '(employer_tags.tags @> :tags)')
+
     def test_reserved_quote(self):
         orders = table('orders',
                        column('order'),
index 9474560ff1762b320cb040e5a0b2adcc4425fe6c..145b2da3cf83b80af40e55b1a15c59fcd997ec4b 100644 (file)
@@ -497,6 +497,22 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
                             "SELECT sum(t.n) AS sum_1 FROM t"
                             )
 
+    def test_aliased_cte_w_union(self):
+        t = select([func.values(1).label("n")]).\
+            cte("t", recursive=True).alias('foo')
+        t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
+        s = select([func.sum(t.c.n)])
+
+        from sqlalchemy.sql.visitors import cloned_traverse
+        cloned = cloned_traverse(s, {}, {})
+
+        self.assert_compile(
+            cloned,
+            "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n "
+            "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM t AS foo "
+            "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo"
+        )
+
     def test_text(self):
         clause = text(
             "select * from table where foo=:bar",