From: Greg Jarzab Date: Thu, 11 Sep 2025 04:52:45 +0000 (-0500) Subject: Fixes: #4950 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=aa1412628b5a9f481cdb7264f8b0a8942ece1997;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fixes: #4950 add CreateTableAs for default dialect and SQLite. add Select.into constructor for CreateTableAs. --- diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 0f9cef6004..e87d165835 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1737,6 +1737,23 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return colspec + def visit_create_table_as(self, element, **kw): + prep = self.preparer + select_sql = self.sql_compiler.process( + element.selectable, literal_binds=True + ) + + parts = [ + "CREATE", + "TEMPORARY" if element.temporary else None, + "TABLE", + "IF NOT EXISTS" if element.if_not_exists else None, + prep.format_table(element.table), + "AS", + select_sql, + ] + return " ".join(p for p in parts if p) + def visit_primary_key_constraint(self, constraint, **kw): # for columns with sqlite_autoincrement=True, # the PRIMARY KEY constraint can only be inline diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index e95eaa5918..0e05a7c5ad 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6937,6 +6937,25 @@ class DDLCompiler(Compiled): text += "\n)%s\n\n" % self.post_create_table(table) return text + def visit_create_table_as(self, element, **kw): + """Default CTAS emission. + + Render a generic CREATE TABLE AS form and let dialects override to + add features (TEMPORARY, IF NOT EXISTS, SELECT INTO on MSSQL, etc.). + + Keep **bind parameters** in the inner SELECT (no literal_binds) + """ + # target identifier (schema-qualified if present) + qualified = self.preparer.format_table(element.table) + + # inner SELECT — keep binds so DDL vs DML + # differences are handled by backends + inner_kw = dict(kw) + inner_kw.pop("literal_binds", None) + select_sql = self.sql_compiler.process(element.selectable, **inner_kw) + + return f"CREATE TABLE {qualified} AS {select_sql}" + def visit_create_column(self, create, first_pk=False, **kw): column = create.element diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 8bd37454e1..7c0c57b1c5 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -32,6 +32,9 @@ from .base import _generative from .base import Executable from .base import SchemaVisitor from .elements import ClauseElement +from .schema import Table +from .selectable import Selectable +from .selectable import TableClause from .. import exc from .. import util from ..util import topological @@ -47,8 +50,6 @@ if typing.TYPE_CHECKING: from .schema import Index from .schema import SchemaItem from .schema import Sequence as Sequence # noqa: F401 - from .schema import Table - from .selectable import TableClause from ..engine.base import Connection from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType @@ -544,6 +545,91 @@ class CreateTable(_CreateBase["Table"]): self.include_foreign_key_constraints = include_foreign_key_constraints +class CreateTableAs(ExecutableDDLElement): + """Represent a CREATE TABLE ... AS (CTAS) statement. + + This creates a new table directly from the output of a SELECT. + The set of columns in the new table is derived from the + SELECT list; constraints, indexes, and defaults are not copied. + + :param selectable: :class:`_sql.Selectable` + The SELECT (or other selectable) providing the columns and rows. + + :param target: str | :class:`_sql.TableClause` + Table name or object. If passed as a string, it must be + unqualified; use the ``schema`` argument for qualification. + + :param schema: str, optional + Schema or owner name. If both ``schema`` and the target object + specify a schema, they must match. + + :param temporary: bool, default False. + If True, render ``TEMPORARY`` (PostgreSQL, MySQL, SQLite), or + a ``#`` temporary table on SQL Server. Dialects that do + not support this option will raise :class:`.CompileError`. + + :param if_not_exists: bool, default False. + If True, render ``IF NOT EXISTS`` where supported + (PostgreSQL, MySQL, SQLite). Dialects that do not support this + option will raise :class:`.CompileError`. + """ + + __visit_name__ = "create_table_as" + inherit_cache = False + + def __init__( + self, + selectable: Selectable, + element: Union[str, TableClause], + *, + schema: Optional[str] = None, + temporary: bool = False, + if_not_exists: bool = False, + ): + if isinstance(element, TableClause): + t_name = element.name + t_schema = element.schema + + if not t_name or not str(t_name).strip(): + raise exc.ArgumentError("Table name must be non-empty") + + if ( + schema is not None + and t_schema is not None + and schema != t_schema + ): + raise exc.ArgumentError( + f"Conflicting schema: target={t_schema!r}, " + f"schema={schema!r}" + ) + final_schema = ( + schema + if (schema is not None and t_schema is None) + else t_schema + ) + elif isinstance(element, str): + if not element.strip(): + raise exc.ArgumentError("Table name must be non-empty") + if "." in element: + raise exc.ArgumentError( + "Target string must be unqualified (use schema=)." + ) + t_name = element + final_schema = schema + else: + raise exc.ArgumentError("target must be a string, TableClause") + + self.table = TableClause(t_name, schema=final_schema) + self.schema = final_schema + self.selectable = selectable + self.temporary = bool(temporary) + self.if_not_exists = bool(if_not_exists) + + @property + def generated_table(self) -> TableClause: + return self.table + + class _DropView(_DropBase["Table"]): """Semi-public 'DROP VIEW' construct. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 4d72377c2b..3660557bb1 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -138,6 +138,7 @@ if TYPE_CHECKING: from .base import ReadOnlyColumnCollection from .cache_key import _CacheKeyTraversalType from .compiler import SQLCompiler + from .ddl import CreateTableAs from .dml import Delete from .dml import Update from .elements import BinaryExpression @@ -148,6 +149,7 @@ if TYPE_CHECKING: from .functions import Function from .schema import ForeignKey from .schema import ForeignKeyConstraint + from .schema import Table from .sqltypes import TableValueType from .type_api import TypeEngine from .visitors import _CloneCallableType @@ -6823,6 +6825,24 @@ class Select( """ return CompoundSelect._create_intersect_all(self, *other) + def into( + self, + target: Union[str, TableClause, Table], + *, + schema: Optional[str] = None, + temporary: bool = False, + if_not_exists: bool = False, + ) -> "CreateTableAs": + from .ddl import CreateTableAs + + return CreateTableAs( + self, + target, + schema=schema, + temporary=temporary, + if_not_exists=if_not_exists, + ) + class ScalarSelect( roles.InElementRole, Generative, GroupedElement, ColumnElement[_T] diff --git a/test/dialect/sqlite/test_dialect.py b/test/dialect/sqlite/test_dialect.py index 27392fc079..addad30be9 100644 --- a/test/dialect/sqlite/test_dialect.py +++ b/test/dialect/sqlite/test_dialect.py @@ -3,7 +3,9 @@ import os from sqlalchemy import and_ +from sqlalchemy import bindparam from sqlalchemy import Column +from sqlalchemy import column from sqlalchemy import Computed from sqlalchemy import create_engine from sqlalchemy import DefaultClause @@ -11,12 +13,14 @@ from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import func from sqlalchemy import inspect +from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import pool from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import sql from sqlalchemy import Table +from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text from sqlalchemy import types as sqltypes @@ -26,6 +30,8 @@ from sqlalchemy.dialects.sqlite import pysqlite as pysqlite_dialect from sqlalchemy.engine.url import make_url from sqlalchemy.schema import CreateTable from sqlalchemy.schema import FetchedValue +from sqlalchemy.sql.ddl import CreateTableAs +from sqlalchemy.sql.ddl import DropTable from sqlalchemy.testing import assert_raises from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults @@ -530,6 +536,319 @@ class AttachedDBTest(fixtures.TablesTest): eq_(row._mapping["name"], "foo") +class CreateTableAsDDLTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = sqlite.dialect() + + @testing.fixture + def src_table(self): + return table("src", column("id"), column("name")) + + @testing.fixture + def src_two_tables(self): + a = table("a", column("id"), column("name")) + b = table("b", column("id"), column("name")) + return a, b + + def test_schema_main(self, src_table): + src = src_table + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst", + schema="main", + ) + self.assert_compile( + stmt, + "CREATE TABLE main.dst AS SELECT src.id FROM src", + ) + + def test_temporary_no_schema(self, src_table): + src = src_table + stmt = CreateTableAs( + select(src.c.id, src.c.name).select_from(src), + "dst", + temporary=True, + ) + self.assert_compile( + stmt, + "CREATE TEMPORARY TABLE dst AS " + "SELECT src.id, src.name FROM src", + ) + + def test_select_shape_where_order_limit(self, src_table): + src = src_table + sel = ( + select(src.c.id, src.c.name) + .select_from(src) + .where(src.c.id > literal(10)) + .order_by(src.c.name) + .limit(5) + .offset(0) + ) + stmt = CreateTableAs(sel, "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT src.id, src.name FROM src " + "WHERE src.id > 10 ORDER BY src.name LIMIT 5 OFFSET 0", + ) + + def test_inline_binds(self, src_table): + src = src_table + sel = select( + literal(1).label("x"), literal("a").label("y") + ).select_from(src) + stmt = CreateTableAs(sel, "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS SELECT 1 AS x, 'a' AS y FROM src", + ) + + def test_explicit_temp_schema_without_keyword(self, src_table): + # When not using temporary but schema is temp (any case), qualify + src = src_table + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst", + schema="TEMP", + ) + self.assert_compile( + stmt, + 'CREATE TABLE "TEMP".dst AS SELECT src.id FROM src', + ) + + def test_if_not_exists(self, src_table): + src = src_table + stmt = CreateTableAs( + select(src.c.id, src.c.name).select_from(src), + "dst", + if_not_exists=True, + ) + self.assert_compile( + stmt, + "CREATE TABLE IF NOT EXISTS dst AS " + "SELECT src.id, src.name FROM src", + ) + + def test_union_all_smoke(self, src_two_tables): + # Proves CTAS wraps a UNION ALL and preserves compound ordering. + a, b = src_two_tables + u = ( + select(a.c.id) + .select_from(a) + .union_all(select(b.c.id).select_from(b)) + .order_by("id") # order-by on the compound + .limit(3) + ) + stmt = CreateTableAs(u, "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT a.id FROM a UNION ALL SELECT b.id FROM b " + "ORDER BY id LIMIT 3 OFFSET 0", + ) + + def test_cte_smoke(self, src_two_tables): + # Proves CTAS works with a WITH-CTE wrapper and labeled column. + a, _ = src_two_tables + cte = select(a.c.id.label("aid")).select_from(a).cte("u") + stmt = CreateTableAs(select(cte.c.aid), "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "WITH u AS (SELECT a.id AS aid FROM a) " + "SELECT u.aid FROM u", + ) + + def test_union_all_with_inlined_literals_smoke(self, src_two_tables): + # Proves literal_binds=True behavior applies across branches. + a, b = src_two_tables + u = ( + select(literal(1).label("x")) + .select_from(a) + .union_all(select(literal("b").label("x")).select_from(b)) + ) + stmt = CreateTableAs(u, "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT 1 AS x FROM a UNION ALL SELECT 'b' AS x FROM b", + ) + + +class CreateTableAsSQLiteBehavior(fixtures.TestBase): + __only_on__ = "sqlite" + __backend__ = True + + @testing.fixture + def ctas_manager(self, connection): + """Executes CreateTableAs and drops them after the test""" + + created_tables = [] + + def execute_ctas(stmt: CreateTableAs): + connection.execute(stmt) + created_tables.append(stmt.generated_table) + return stmt + + yield execute_ctas + + for t in created_tables: + connection.execute(DropTable(t, if_exists=True)) + + @testing.fixture + def source_table(self, connection): + connection.exec_driver_sql( + """ + CREATE TABLE IF NOT EXISTS src ( + id INTEGER PRIMARY KEY, + name TEXT + )""" + ) + connection.exec_driver_sql( + "INSERT INTO src (name) VALUES ('a'), ('b')" + ) + yield table("src", column("id"), column("name")) + connection.exec_driver_sql("DROP TABLE IF EXISTS src") + + @testing.fixture + def seeded_tables(self, connection): + connection.exec_driver_sql("CREATE TABLE a (id INTEGER)") + connection.exec_driver_sql("CREATE TABLE b (id INTEGER)") + + def seed(a_values, b_values): + if a_values: + stmt_a = text("INSERT INTO a (id) VALUES (:v)") + connection.execute(stmt_a, [{"v": v} for v in a_values]) + if b_values: + stmt_b = text("INSERT INTO b (id) VALUES (:v)") + connection.execute(stmt_b, [{"v": v} for v in b_values]) + + yield seed + + connection.exec_driver_sql("DROP TABLE a") + connection.exec_driver_sql("DROP TABLE b") + + def test_create_table_as_creates_table_and_copies_rows( + self, connection, source_table, ctas_manager + ): + src = source_table + stmt = CreateTableAs( + select(src.c.id, src.c.name).select_from(src), + "dst", + ) + ctas_manager(stmt) + + insp = inspect(connection) + cols = insp.get_columns("dst") + assert [c["name"] for c in cols] == ["id", "name"] + + # In SQLite CREATE TABLE AS does NOT carry over PK/constraints + pk = insp.get_pk_constraint("dst")["constrained_columns"] + assert pk == [] + + # data copied + count = connection.exec_driver_sql("SELECT COUNT(*) FROM dst").scalar() + assert count == 2 + + def test_if_not_exists_does_not_error( + self, connection, source_table, ctas_manager + ): + src = source_table + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst", + if_not_exists=True, + ) + # first run creates; second run should not error + ctas_manager(stmt) + ctas_manager(stmt) + + exists = connection.exec_driver_sql( + "SELECT name FROM sqlite_master WHERE type='table' AND name='dst'" + ).fetchall() + assert exists + + def test_temporary_with_temp_schema_ok( + self, connection, source_table, ctas_manager + ): + src = source_table + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst_tmp", + temporary=True, + schema="temp", # accepted; still emits CREATE TEMPORARY TABLE ... + ) + ctas_manager(stmt) + + # verify it was created as a temp table + assert ( + connection.exec_driver_sql( + "SELECT name FROM sqlite_temp_master " + "WHERE type='table' AND name='dst_tmp'" + ).fetchone() + is not None + ) + + def test_literal_inlining_inside_select( + self, connection, source_table, ctas_manager + ): + src = source_table + sel = select( + (src.c.id + 1).label("id2"), + literal("x").label("tag"), + ).select_from(src) + + stmt = CreateTableAs(sel, "dst2") + ctas_manager(stmt) + + rows = connection.exec_driver_sql( + "SELECT COUNT(*), MIN(tag), MAX(tag) FROM dst2" + ).fetchone() + assert rows[0] == 2 and rows[1] == "x" and rows[2] == "x" + + def test_create_table_as_with_bind_param_executes( + self, connection, source_table, ctas_manager + ): + src = source_table + sel = ( + select(src.c.id, src.c.name) + .select_from(src) + .where(src.c.name == bindparam("p", value="a")) + ) + + stmt = CreateTableAs(sel, "dst_bind") + ctas_manager(stmt) + + rows = connection.exec_driver_sql( + "SELECT COUNT(*), MIN(name), MAX(name) FROM dst_bind" + ).fetchone() + assert rows[0] == 1 and rows[1] == "a" and rows[2] == "a" + + def test_compound_select_smoke( + self, connection, seeded_tables, ctas_manager + ): + # UNION ALL + ORDER/LIMIT survives inside CTAS + seeded_tables(a_values=[1, 3], b_values=[2, 4]) + + sel = ( + select(text("id")) + .select_from(text("a")) + .union_all(select(text("id")).select_from(text("b"))) + .order_by(text("id")) + .limit(3) + ) + stmt = CreateTableAs(sel, "dst_union") + ctas_manager(stmt) + + vals = [ + r[0] + for r in connection.exec_driver_sql( + "SELECT id FROM dst_union ORDER BY id" + ).fetchall() + ] + assert vals == [1, 2, 3] + + class InsertTest(fixtures.TestBase, AssertsExecutionResults): """Tests inserts and autoincrement.""" diff --git a/test/sql/test_create_table_as.py b/test/sql/test_create_table_as.py new file mode 100644 index 0000000000..df2f5140f3 --- /dev/null +++ b/test/sql/test_create_table_as.py @@ -0,0 +1,281 @@ +import re + +from sqlalchemy import bindparam +from sqlalchemy import literal +from sqlalchemy.engine import default as default_engine +from sqlalchemy.exc import ArgumentError +from sqlalchemy.sql import column +from sqlalchemy.sql import select +from sqlalchemy.sql import table +from sqlalchemy.sql.ddl import CreateTableAs +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.assertions import AssertsCompiledSQL +from sqlalchemy.testing.assertions import expect_raises_message + + +class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): + __dialect__ = "default" + + def _source(self): + return table("src", column("id"), column("name")) + + def assert_inner_params(self, stmt, expected, dialect=None): + d = default_engine.DefaultDialect() if dialect is None else dialect + inner = stmt.selectable.compile(dialect=d) + assert ( + inner.params == expected + ), f"Got {inner.params}, expected {expected}" + + def test_basic_element(self): + src = self._source() + stmt = CreateTableAs( + select(src.c.id, src.c.name).select_from(src), + "dst", + ) + self.assert_compile( + stmt, + "CREATE TABLE dst AS SELECT src.id, src.name FROM src", + ) + + def test_schema_element_qualified(self): + src = self._source() + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst", + schema="analytics", + ) + self.assert_compile( + stmt, + "CREATE TABLE analytics.dst AS SELECT src.id FROM src", + ) + + def test_blank_schema_treated_as_none(self): + src = self._source() + stmt = CreateTableAs( + select(src.c.id).select_from(src), "dst", schema="" + ) + self.assert_compile(stmt, "CREATE TABLE dst AS SELECT src.id FROM src") + + def test_binds_preserved(self): + src = self._source() + stmt = CreateTableAs( + select(bindparam("tag", value="x").label("tag")).select_from(src), + "dst", + ) + self.assert_compile( + stmt, + "CREATE TABLE dst AS SELECT :tag AS tag FROM src", + ) + self.assert_inner_params(stmt, {"tag": "x"}) + + def test_flags_not_rendered_in_default(self): + src = self._source() + stmt = CreateTableAs( + select(src.c.id).select_from(src), + "dst", + schema="sch", + temporary=True, + if_not_exists=True, + ) + # Default baseline omits TEMPORARY / IF NOT EXISTS; dialects add them. + self.assert_compile( + stmt, + "CREATE TABLE sch.dst AS SELECT src.id FROM src", + ) + + def test_join_with_binds_preserved(self): + a = table("a", column("id"), column("name")) + b = table("b", column("id"), column("status")) + + s = ( + select(a.c.id, a.c.name) + .select_from(a.join(b, a.c.id == b.c.id)) + .where(b.c.status == bindparam("p_status")) + ).into("dest") + + # Ensure WHERE survives into CTAS and params are preserved + self.assert_compile( + s, + "CREATE TABLE dest AS " + "SELECT a.id, a.name FROM a JOIN b ON a.id = b.id " + "WHERE b.status = :p_status", + ) + self.assert_inner_params(s, {"p_status": None}) + + def test_into_equivalent_to_element(self): + src = self._source() + s = select(src.c.id).select_from(src).where(src.c.id == bindparam("p")) + via_into = s.into("dst") + via_element = CreateTableAs(s, "dst") + + self.assert_compile( + via_into, + "CREATE TABLE dst AS SELECT src.id FROM src WHERE src.id = :p", + ) + self.assert_compile( + via_element, + "CREATE TABLE dst AS SELECT src.id FROM src WHERE src.id = :p", + ) + # Param parity (inner SELECT of both) + self.assert_inner_params(via_into, {"p": None}) + self.assert_inner_params(via_element, {"p": None}) + + def test_into_does_not_mutate_original_select(self): + src = self._source() + s = select(src.c.id).select_from(src).where(src.c.id == 5) + + # compile original SELECT + self.assert_compile( + s, + "SELECT src.id FROM src WHERE src.id = :id_1", + ) + + # build CTAS + _ = s.into("dst") + + # original is still a SELECT + self.assert_compile( + s, + "SELECT src.id FROM src WHERE src.id = :id_1", + ) + + def test_into_with_schema_argument(self): + src = self._source() + s = select(src.c.id).select_from(src).into("t", schema="analytics") + self.assert_compile( + s, + "CREATE TABLE analytics.t AS SELECT src.id FROM src", + ) + + def test_target_table_without_schema_accepts_schema_kw(self): + tgt = table("dst") + + s = select(bindparam("v", value=1).label("anon_1")).select_from( + table("x") + ) + + stmt = CreateTableAs( + s, + tgt, + schema="sch", + ) + self.assert_compile( + stmt, + "CREATE TABLE sch.dst AS SELECT :v AS anon_1 FROM x", + ) + self.assert_inner_params(stmt, {"v": 1}) + + def test_target_as_table_with_schema_and_conflict(self): + # Target object with schema set + tgt = table("dst", schema="sch") + + # Conflicting schema in ctor should raise ArgumentError + with expect_raises_message( + ArgumentError, + r"Conflicting schema", + ): + CreateTableAs( + select(literal(1)).select_from(table("x")), + tgt, + schema="other", + ) + + def test_target_string_must_be_unqualified(self): + src = self._source() + with expect_raises_message( + ArgumentError, + re.escape("Target string must be unqualified (use schema=)."), + ): + CreateTableAs(select(src.c.id).select_from(src), "sch.dst") + + def test_empty_name(self): + with expect_raises_message( + ArgumentError, "Table name must be non-empty" + ): + CreateTableAs(select(literal(1)), "") + + def test_generated_table_property(self): + src = self._source() + stmt = CreateTableAs( + select(src.c.id).select_from(src), "dst", schema="sch" + ) + gt = stmt.generated_table + assert gt.name == "dst" + assert gt.schema == "sch" + + def test_labels_in_select_list_preserved(self): + src = self._source() + stmt = CreateTableAs( + select( + src.c.id.label("user_id"), src.c.name.label("user_name") + ).select_from(src), + "dst", + ) + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT src.id AS user_id, src.name AS user_name FROM src", + ) + + def test_distinct_and_group_by_survive(self): + src = self._source() + sel = ( + select(src.c.name).select_from(src).distinct().group_by(src.c.name) + ) + stmt = CreateTableAs(sel, "dst") + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT DISTINCT src.name FROM src GROUP BY src.name", + ) + + def test_union_all_with_binds_preserved(self): + a = table("a", column("id")) + b = table("b", column("id")) + + # Named binds so params are deterministic + s1 = ( + select(a.c.id) + .select_from(a) + .where(a.c.id == bindparam("p_a", value=1)) + ) + s2 = ( + select(b.c.id) + .select_from(b) + .where(b.c.id == bindparam("p_b", value=2)) + ) + + u_all = s1.union_all(s2) + stmt = CreateTableAs(u_all, "dst") + + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT a.id FROM a WHERE a.id = :p_a " + "UNION ALL SELECT b.id FROM b WHERE b.id = :p_b", + ) + + self.assert_inner_params(stmt, {"p_a": 1, "p_b": 2}) + + def test_union_labels_follow_first_select(self): + # Many engines take column names + # of a UNION from the first SELECT’s labels. + a = table("a", column("val")) + b = table("b", column("val")) + + s1 = select(a.c.val.label("first_name")).select_from(a) + s2 = select(b.c.val).select_from(b) # unlabeled second branch + + u = s1.union(s2) + stmt = CreateTableAs(u, "dst") + + # We only assert what’s stable across dialects: + # - first SELECT has the label + # - a UNION occurs + self.assert_compile( + stmt, + "CREATE TABLE dst AS " + "SELECT a.val AS first_name FROM a " + "UNION " + "SELECT b.val FROM b", + )