--- /dev/null
+from .. import fixtures, config
+from ..assertions import eq_
+
+from sqlalchemy import Integer, String, select
+from sqlalchemy import ForeignKey
+from sqlalchemy import testing
+
+from ..schema import Table, Column
+
+
+class CTETest(fixtures.TablesTest):
+ __backend__ = True
+ __requires__ = 'ctes',
+
+ run_inserts = 'each'
+ run_deletes = 'each'
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("some_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", ForeignKey("some_table.id")))
+
+ Table("some_other_table", metadata,
+ Column('id', Integer, primary_key=True),
+ Column('data', String(50)),
+ Column("parent_id", Integer))
+
+ @classmethod
+ def insert_data(cls):
+ config.db.execute(
+ cls.tables.some_table.insert(),
+ [
+ {"id": 1, "data": "d1", "parent_id": None},
+ {"id": 2, "data": "d2", "parent_id": 1},
+ {"id": 3, "data": "d3", "parent_id": 1},
+ {"id": 4, "data": "d4", "parent_id": 3},
+ {"id": 5, "data": "d5", "parent_id": 3}
+ ]
+ )
+
+ def test_select_nonrecursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte("some_cte")
+ result = conn.execute(
+ select([cte.c.data]).where(cte.c.data.in_(["d4", "d5"]))
+ )
+ eq_(result.fetchall(), [("d4", )])
+
+ def test_select_recursive_round_trip(self):
+ some_table = self.tables.some_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])).cte(
+ "some_cte", recursive=True)
+
+ cte_alias = cte.alias("c1")
+ st1 = some_table.alias()
+ # note that SQL Server requires this to be UNION ALL,
+ # can't be UNION
+ cte = cte.union_all(
+ select([st1]).where(st1.c.id == cte_alias.c.parent_id)
+ )
+ result = conn.execute(
+ select([cte.c.data]).where(
+ cte.c.data != "d2").order_by(cte.c.data.desc())
+ )
+ eq_(
+ result.fetchall(),
+ [('d4',), ('d3',), ('d3',), ('d1',), ('d1',), ('d1',)]
+ )
+
+ def test_insert_from_select_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.insert().from_select(
+ ["id", "data", "parent_id"],
+ select([cte])
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.update_from
+ def test_update_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.update().values(parent_id=5).where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (2, "d2", 5),
+ (3, "d3", 5), (4, "d4", 5), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ @testing.requires.delete_from
+ def test_delete_from_round_trip(self):
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data == cte.c.data
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
+
+ @testing.requires.ctes_with_update_delete
+ def test_delete_scalar_subq_round_trip(self):
+
+ some_table = self.tables.some_table
+ some_other_table = self.tables.some_other_table
+
+ with config.db.connect() as conn:
+ conn.execute(
+ some_other_table.insert().from_select(
+ ['id', 'data', 'parent_id'],
+ select([some_table])
+ )
+ )
+
+ cte = select([some_table]).where(
+ some_table.c.data.in_(["d2", "d3", "d4"])
+ ).cte("some_cte")
+ conn.execute(
+ some_other_table.delete().where(
+ some_other_table.c.data ==
+ select([cte.c.data]).where(
+ cte.c.id == some_other_table.c.id)
+ )
+ )
+ eq_(
+ conn.execute(
+ select([some_other_table]).order_by(some_other_table.c.id)
+ ).fetchall(),
+ [
+ (1, "d1", None), (5, "d5", 3)
+ ]
+ )
checkparams={"name_1": "bar"}
)
+ def test_insert_from_select_cte_follows_insert_one(self):
+ dialect = default.DefaultDialect()
+ dialect.cte_follows_insert = True
+
+ table1 = self.tables.mytable
+
+ cte = select([table1.c.name]).where(table1.c.name == 'bar').cte()
+
+ sel = select([table1.c.myid, table1.c.name]).where(
+ table1.c.name == cte.c.name)
+
+ ins = self.tables.myothertable.insert().\
+ from_select(("otherid", "othername"), sel)
+ self.assert_compile(
+ ins,
+ "INSERT INTO myothertable (otherid, othername) "
+ "WITH anon_1 AS "
+ "(SELECT mytable.name AS name FROM mytable "
+ "WHERE mytable.name = :name_1) "
+ "SELECT mytable.myid, mytable.name FROM mytable, anon_1 "
+ "WHERE mytable.name = anon_1.name",
+ checkparams={"name_1": "bar"},
+ dialect=dialect
+ )
+
def test_insert_from_select_cte_two(self):
table1 = self.tables.mytable
"SELECT c.myid, c.name, c.description FROM c"
)
+ def test_insert_from_select_cte_follows_insert_two(self):
+ dialect = default.DefaultDialect()
+ dialect.cte_follows_insert = True
+ table1 = self.tables.mytable
+
+ cte = table1.select().cte("c")
+ stmt = cte.select()
+ ins = table1.insert().from_select(table1.c, stmt)
+
+ self.assert_compile(
+ ins,
+ "INSERT INTO mytable (myid, name, description) "
+ "WITH c AS (SELECT mytable.myid AS myid, mytable.name AS name, "
+ "mytable.description AS description FROM mytable) "
+ "SELECT c.myid, c.name, c.description FROM c",
+ dialect=dialect
+ )
+
def test_insert_from_select_select_alt_ordering(self):
table1 = self.tables.mytable
sel = select([table1.c.name, table1.c.myid]).where(