From 7de8a109b892fd91222ce2f59c388ca275021ddb Mon Sep 17 00:00:00 2001 From: Greg Jarzab Date: Mon, 22 Sep 2025 23:09:17 -0500 Subject: [PATCH] use fixture for CreateTableAs default dialect tests --- lib/sqlalchemy/sql/ddl.py | 6 +-- test/sql/test_create_table_as.py | 88 ++++++++++++++++++-------------- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 7c0c57b1c5..0c1752c034 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -555,12 +555,12 @@ class CreateTableAs(ExecutableDDLElement): :param selectable: :class:`_sql.Selectable` The SELECT (or other selectable) providing the columns and rows. - :param target: str | :class:`_sql.TableClause` + :param element: str | :class:`_sql.TableClause` Table name or object. If passed as a string, it must be unqualified; use the ``schema`` argument for qualification. :param schema: str, optional - Schema or owner name. If both ``schema`` and the target object + Schema or owner name. If both ``schema`` and the element object specify a schema, they must match. :param temporary: bool, default False. @@ -599,7 +599,7 @@ class CreateTableAs(ExecutableDDLElement): and schema != t_schema ): raise exc.ArgumentError( - f"Conflicting schema: target={t_schema!r}, " + f"Conflicting schema: element={t_schema!r}, " f"schema={schema!r}" ) final_schema = ( diff --git a/test/sql/test_create_table_as.py b/test/sql/test_create_table_as.py index df2f5140f3..815d141fa7 100644 --- a/test/sql/test_create_table_as.py +++ b/test/sql/test_create_table_as.py @@ -2,6 +2,7 @@ import re from sqlalchemy import bindparam from sqlalchemy import literal +from sqlalchemy import testing from sqlalchemy.engine import default as default_engine from sqlalchemy.exc import ArgumentError from sqlalchemy.sql import column @@ -16,9 +17,16 @@ from sqlalchemy.testing.assertions import expect_raises_message class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def _source(self): + @testing.fixture + def src_table(self): return table("src", column("id"), column("name")) + @testing.fixture + def src_two_tables(self): + a = table("a", column("id"), column("name")) + b = table("b", column("id"), column("status")) + return a, b + def assert_inner_params(self, stmt, expected, dialect=None): d = default_engine.DefaultDialect() if dialect is None else dialect inner = stmt.selectable.compile(dialect=d) @@ -26,8 +34,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): inner.params == expected ), f"Got {inner.params}, expected {expected}" - def test_basic_element(self): - src = self._source() + def test_basic_element(self, src_table): + src = src_table stmt = CreateTableAs( select(src.c.id, src.c.name).select_from(src), "dst", @@ -37,8 +45,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE dst AS SELECT src.id, src.name FROM src", ) - def test_schema_element_qualified(self): - src = self._source() + def test_schema_element_qualified(self, src_table): + src = src_table stmt = CreateTableAs( select(src.c.id).select_from(src), "dst", @@ -49,15 +57,15 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE analytics.dst AS SELECT src.id FROM src", ) - def test_blank_schema_treated_as_none(self): - src = self._source() + def test_blank_schema_treated_as_none(self, src_table): + src = src_table stmt = CreateTableAs( select(src.c.id).select_from(src), "dst", schema="" ) self.assert_compile(stmt, "CREATE TABLE dst AS SELECT src.id FROM src") - def test_binds_preserved(self): - src = self._source() + def test_binds_preserved(self, src_table): + src = src_table stmt = CreateTableAs( select(bindparam("tag", value="x").label("tag")).select_from(src), "dst", @@ -68,8 +76,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): ) self.assert_inner_params(stmt, {"tag": "x"}) - def test_flags_not_rendered_in_default(self): - src = self._source() + def test_flags_not_rendered_in_default(self, src_table): + src = src_table stmt = CreateTableAs( select(src.c.id).select_from(src), "dst", @@ -83,28 +91,31 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE sch.dst AS SELECT src.id FROM src", ) - def test_join_with_binds_preserved(self): - a = table("a", column("id"), column("name")) - b = table("b", column("id"), column("status")) + def test_join_with_binds_preserved(self, src_two_tables): + a, b = src_two_tables s = ( select(a.c.id, a.c.name) .select_from(a.join(b, a.c.id == b.c.id)) - .where(b.c.status == bindparam("p_status")) - ).into("dest") + .where(b.c.status == bindparam("p_status", value="active")) + ).into("dst") # Ensure WHERE survives into CTAS and params are preserved self.assert_compile( s, - "CREATE TABLE dest AS " + "CREATE TABLE dst AS " "SELECT a.id, a.name FROM a JOIN b ON a.id = b.id " "WHERE b.status = :p_status", ) - self.assert_inner_params(s, {"p_status": None}) + self.assert_inner_params(s, {"p_status": "active"}) - def test_into_equivalent_to_element(self): - src = self._source() - s = select(src.c.id).select_from(src).where(src.c.id == bindparam("p")) + def test_into_equivalent_to_element(self, src_table): + src = src_table + s = ( + select(src.c.id) + .select_from(src) + .where(src.c.id == bindparam("p", value=2)) + ) via_into = s.into("dst") via_element = CreateTableAs(s, "dst") @@ -117,11 +128,11 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "CREATE TABLE dst AS SELECT src.id FROM src WHERE src.id = :p", ) # Param parity (inner SELECT of both) - self.assert_inner_params(via_into, {"p": None}) - self.assert_inner_params(via_element, {"p": None}) + self.assert_inner_params(via_into, {"p": 2}) + self.assert_inner_params(via_element, {"p": 2}) - def test_into_does_not_mutate_original_select(self): - src = self._source() + def test_into_does_not_mutate_original_select(self, src_table): + src = src_table s = select(src.c.id).select_from(src).where(src.c.id == 5) # compile original SELECT @@ -139,8 +150,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT src.id FROM src WHERE src.id = :id_1", ) - def test_into_with_schema_argument(self): - src = self._source() + def test_into_with_schema_argument(self, src_table): + src = src_table s = select(src.c.id).select_from(src).into("t", schema="analytics") self.assert_compile( s, @@ -180,8 +191,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): schema="other", ) - def test_target_string_must_be_unqualified(self): - src = self._source() + def test_target_string_must_be_unqualified(self, src_table): + src = src_table with expect_raises_message( ArgumentError, re.escape("Target string must be unqualified (use schema=)."), @@ -194,8 +205,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): ): CreateTableAs(select(literal(1)), "") - def test_generated_table_property(self): - src = self._source() + def test_generated_table_property(self, src_table): + src = src_table stmt = CreateTableAs( select(src.c.id).select_from(src), "dst", schema="sch" ) @@ -203,8 +214,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): assert gt.name == "dst" assert gt.schema == "sch" - def test_labels_in_select_list_preserved(self): - src = self._source() + def test_labels_in_select_list_preserved(self, src_table): + src = src_table stmt = CreateTableAs( select( src.c.id.label("user_id"), src.c.name.label("user_name") @@ -217,8 +228,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT src.id AS user_id, src.name AS user_name FROM src", ) - def test_distinct_and_group_by_survive(self): - src = self._source() + def test_distinct_and_group_by_survive(self, src_table): + src = src_table sel = ( select(src.c.name).select_from(src).distinct().group_by(src.c.name) ) @@ -229,9 +240,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): "SELECT DISTINCT src.name FROM src GROUP BY src.name", ) - def test_union_all_with_binds_preserved(self): - a = table("a", column("id")) - b = table("b", column("id")) + def test_union_all_with_binds_preserved(self, src_two_tables): + a, b = src_two_tables # Named binds so params are deterministic s1 = ( @@ -257,7 +267,7 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_inner_params(stmt, {"p_a": 1, "p_b": 2}) - def test_union_labels_follow_first_select(self): + def test_union_labels_follow_first_select(self, src_two_tables): # Many engines take column names # of a UNION from the first SELECT’s labels. a = table("a", column("val")) -- 2.47.3