from sqlalchemy.testing import fixtures, eq_
from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
-from sqlalchemy.sql import table, column, select, func, literal, exists, and_
+from sqlalchemy.sql import table, column, select, func, literal, exists, \
+ and_, bindparam
from sqlalchemy.dialects import mssql
from sqlalchemy.engine import default
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import quoted_name
+from sqlalchemy.sql.visitors import cloned_traverse
+
class CTETest(fixtures.TestBase, AssertsCompiledSQL):
"FROM regional_sales AS rs WHERE "
"rs.amount < :amount_2")
+ cloned = cloned_traverse(s, {}, {})
+ self.assert_compile(cloned,
+ "WITH regional_sales AS "
+ "(SELECT orders.region AS region, "
+ "orders.amount AS amount FROM orders) "
+ "SELECT rs.region FROM regional_sales AS rs "
+ "WHERE rs.amount > :amount_1 "
+ "UNION ALL SELECT rs.region "
+ "FROM regional_sales AS rs WHERE "
+ "rs.amount < :amount_2")
+
+ def test_cloned_alias(self):
+ entity = table(
+ 'entity', column('id'), column('employer_id'), column('name'))
+ tag = table('tag', column('tag'), column('entity_id'))
+
+ tags = select([
+ tag.c.entity_id,
+ func.array_agg(tag.c.tag).label('tags'),
+ ]).group_by(tag.c.entity_id).cte('unaliased_tags')
+
+ entity_tags = tags.alias(name='entity_tags')
+ employer_tags = tags.alias(name='employer_tags')
+
+ q = (
+ select([entity.c.name])
+ .select_from(
+ entity
+ .outerjoin(entity_tags, tags.c.entity_id == entity.c.id)
+ .outerjoin(employer_tags,
+ tags.c.entity_id == entity.c.employer_id)
+ )
+ .where(entity_tags.c.tags.op('@>')(bindparam('tags')))
+ .where(employer_tags.c.tags.op('@>')(bindparam('tags')))
+ )
+
+ self.assert_compile(
+ q,
+ 'WITH unaliased_tags AS '
+ '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags '
+ 'FROM tag GROUP BY tag.entity_id)'
+ ' SELECT entity.name '
+ 'FROM entity '
+ 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON '
+ 'unaliased_tags.entity_id = entity.id '
+ 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON '
+ 'unaliased_tags.entity_id = entity.employer_id '
+ 'WHERE (entity_tags.tags @> :tags) AND '
+ '(employer_tags.tags @> :tags)'
+ )
+
+ cloned = q.params(tags=['tag1', 'tag2'])
+ self.assert_compile(
+ cloned,
+ 'WITH unaliased_tags AS '
+ '(SELECT tag.entity_id AS entity_id, array_agg(tag.tag) AS tags '
+ 'FROM tag GROUP BY tag.entity_id)'
+ ' SELECT entity.name '
+ 'FROM entity '
+ 'LEFT OUTER JOIN unaliased_tags AS entity_tags ON '
+ 'unaliased_tags.entity_id = entity.id '
+ 'LEFT OUTER JOIN unaliased_tags AS employer_tags ON '
+ 'unaliased_tags.entity_id = entity.employer_id '
+ 'WHERE (entity_tags.tags @> :tags) AND '
+ '(employer_tags.tags @> :tags)')
+
def test_reserved_quote(self):
orders = table('orders',
column('order'),
"SELECT sum(t.n) AS sum_1 FROM t"
)
+ def test_aliased_cte_w_union(self):
+ t = select([func.values(1).label("n")]).\
+ cte("t", recursive=True).alias('foo')
+ t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
+ s = select([func.sum(t.c.n)])
+
+ from sqlalchemy.sql.visitors import cloned_traverse
+ cloned = cloned_traverse(s, {}, {})
+
+ self.assert_compile(
+ cloned,
+ "WITH RECURSIVE foo(n) AS (SELECT values(:values_1) AS n "
+ "UNION ALL SELECT foo.n + :n_1 AS anon_1 FROM t AS foo "
+ "WHERE foo.n < :n_2) SELECT sum(foo.n) AS sum_1 FROM foo"
+ )
+
def test_text(self):
clause = text(
"select * from table where foo=:bar",