]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use fixture for CreateTableAs default dialect tests 12860/head
authorGreg Jarzab <greg.jarzab@gmail.com>
Tue, 23 Sep 2025 04:09:17 +0000 (23:09 -0500)
committerGreg Jarzab <greg.jarzab@gmail.com>
Tue, 23 Sep 2025 04:09:17 +0000 (23:09 -0500)
lib/sqlalchemy/sql/ddl.py
test/sql/test_create_table_as.py

index 7c0c57b1c561738c0972f61b03aa9fc8ca4bcdf6..0c1752c03439c315450c4b3a547df5e8951e39ea 100644 (file)
@@ -555,12 +555,12 @@ class CreateTableAs(ExecutableDDLElement):
     :param selectable: :class:`_sql.Selectable`
         The SELECT (or other selectable) providing the columns and rows.
 
-    :param target: str | :class:`_sql.TableClause`
+    :param element: str | :class:`_sql.TableClause`
         Table name or object. If passed as a string, it must be
         unqualified; use the ``schema`` argument for qualification.
 
     :param schema: str, optional
-        Schema or owner name.  If both ``schema`` and the target object
+        Schema or owner name.  If both ``schema`` and the element object
         specify a schema, they must match.
 
     :param temporary: bool, default False.
@@ -599,7 +599,7 @@ class CreateTableAs(ExecutableDDLElement):
                 and schema != t_schema
             ):
                 raise exc.ArgumentError(
-                    f"Conflicting schema: target={t_schema!r}, "
+                    f"Conflicting schema: element={t_schema!r}, "
                     f"schema={schema!r}"
                 )
             final_schema = (
index df2f5140f3783ca8e0a1f0afcc9d7b70f0bff50b..815d141fa75a8ee77ccce36a4ce6f06f7157c8cd 100644 (file)
@@ -2,6 +2,7 @@ import re
 
 from sqlalchemy import bindparam
 from sqlalchemy import literal
+from sqlalchemy import testing
 from sqlalchemy.engine import default as default_engine
 from sqlalchemy.exc import ArgumentError
 from sqlalchemy.sql import column
@@ -16,9 +17,16 @@ from sqlalchemy.testing.assertions import expect_raises_message
 class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def _source(self):
+    @testing.fixture
+    def src_table(self):
         return table("src", column("id"), column("name"))
 
+    @testing.fixture
+    def src_two_tables(self):
+        a = table("a", column("id"), column("name"))
+        b = table("b", column("id"), column("status"))
+        return a, b
+
     def assert_inner_params(self, stmt, expected, dialect=None):
         d = default_engine.DefaultDialect() if dialect is None else dialect
         inner = stmt.selectable.compile(dialect=d)
@@ -26,8 +34,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             inner.params == expected
         ), f"Got {inner.params}, expected {expected}"
 
-    def test_basic_element(self):
-        src = self._source()
+    def test_basic_element(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(src.c.id, src.c.name).select_from(src),
             "dst",
@@ -37,8 +45,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE TABLE dst AS SELECT src.id, src.name FROM src",
         )
 
-    def test_schema_element_qualified(self):
-        src = self._source()
+    def test_schema_element_qualified(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(src.c.id).select_from(src),
             "dst",
@@ -49,15 +57,15 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE TABLE analytics.dst AS SELECT src.id FROM src",
         )
 
-    def test_blank_schema_treated_as_none(self):
-        src = self._source()
+    def test_blank_schema_treated_as_none(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(src.c.id).select_from(src), "dst", schema=""
         )
         self.assert_compile(stmt, "CREATE TABLE dst AS SELECT src.id FROM src")
 
-    def test_binds_preserved(self):
-        src = self._source()
+    def test_binds_preserved(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(bindparam("tag", value="x").label("tag")).select_from(src),
             "dst",
@@ -68,8 +76,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
         )
         self.assert_inner_params(stmt, {"tag": "x"})
 
-    def test_flags_not_rendered_in_default(self):
-        src = self._source()
+    def test_flags_not_rendered_in_default(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(src.c.id).select_from(src),
             "dst",
@@ -83,28 +91,31 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE TABLE sch.dst AS SELECT src.id FROM src",
         )
 
-    def test_join_with_binds_preserved(self):
-        a = table("a", column("id"), column("name"))
-        b = table("b", column("id"), column("status"))
+    def test_join_with_binds_preserved(self, src_two_tables):
+        a, b = src_two_tables
 
         s = (
             select(a.c.id, a.c.name)
             .select_from(a.join(b, a.c.id == b.c.id))
-            .where(b.c.status == bindparam("p_status"))
-        ).into("dest")
+            .where(b.c.status == bindparam("p_status", value="active"))
+        ).into("dst")
 
         # Ensure WHERE survives into CTAS and params are preserved
         self.assert_compile(
             s,
-            "CREATE TABLE dest AS "
+            "CREATE TABLE dst AS "
             "SELECT a.id, a.name FROM a JOIN b ON a.id = b.id "
             "WHERE b.status = :p_status",
         )
-        self.assert_inner_params(s, {"p_status": None})
+        self.assert_inner_params(s, {"p_status": "active"})
 
-    def test_into_equivalent_to_element(self):
-        src = self._source()
-        s = select(src.c.id).select_from(src).where(src.c.id == bindparam("p"))
+    def test_into_equivalent_to_element(self, src_table):
+        src = src_table
+        s = (
+            select(src.c.id)
+            .select_from(src)
+            .where(src.c.id == bindparam("p", value=2))
+        )
         via_into = s.into("dst")
         via_element = CreateTableAs(s, "dst")
 
@@ -117,11 +128,11 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE TABLE dst AS SELECT src.id FROM src WHERE src.id = :p",
         )
         # Param parity (inner SELECT of both)
-        self.assert_inner_params(via_into, {"p": None})
-        self.assert_inner_params(via_element, {"p": None})
+        self.assert_inner_params(via_into, {"p": 2})
+        self.assert_inner_params(via_element, {"p": 2})
 
-    def test_into_does_not_mutate_original_select(self):
-        src = self._source()
+    def test_into_does_not_mutate_original_select(self, src_table):
+        src = src_table
         s = select(src.c.id).select_from(src).where(src.c.id == 5)
 
         # compile original SELECT
@@ -139,8 +150,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT src.id FROM src WHERE src.id = :id_1",
         )
 
-    def test_into_with_schema_argument(self):
-        src = self._source()
+    def test_into_with_schema_argument(self, src_table):
+        src = src_table
         s = select(src.c.id).select_from(src).into("t", schema="analytics")
         self.assert_compile(
             s,
@@ -180,8 +191,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
                 schema="other",
             )
 
-    def test_target_string_must_be_unqualified(self):
-        src = self._source()
+    def test_target_string_must_be_unqualified(self, src_table):
+        src = src_table
         with expect_raises_message(
             ArgumentError,
             re.escape("Target string must be unqualified (use schema=)."),
@@ -194,8 +205,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
         ):
             CreateTableAs(select(literal(1)), "")
 
-    def test_generated_table_property(self):
-        src = self._source()
+    def test_generated_table_property(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(src.c.id).select_from(src), "dst", schema="sch"
         )
@@ -203,8 +214,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
         assert gt.name == "dst"
         assert gt.schema == "sch"
 
-    def test_labels_in_select_list_preserved(self):
-        src = self._source()
+    def test_labels_in_select_list_preserved(self, src_table):
+        src = src_table
         stmt = CreateTableAs(
             select(
                 src.c.id.label("user_id"), src.c.name.label("user_name")
@@ -217,8 +228,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT src.id AS user_id, src.name AS user_name FROM src",
         )
 
-    def test_distinct_and_group_by_survive(self):
-        src = self._source()
+    def test_distinct_and_group_by_survive(self, src_table):
+        src = src_table
         sel = (
             select(src.c.name).select_from(src).distinct().group_by(src.c.name)
         )
@@ -229,9 +240,8 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT DISTINCT src.name FROM src GROUP BY src.name",
         )
 
-    def test_union_all_with_binds_preserved(self):
-        a = table("a", column("id"))
-        b = table("b", column("id"))
+    def test_union_all_with_binds_preserved(self, src_two_tables):
+        a, b = src_two_tables
 
         # Named binds so params are deterministic
         s1 = (
@@ -257,7 +267,7 @@ class CreateTableAsDefaultDialectTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_inner_params(stmt, {"p_a": 1, "p_b": 2})
 
-    def test_union_labels_follow_first_select(self):
+    def test_union_labels_follow_first_select(self, src_two_tables):
         # Many engines take column names
         # of a UNION from the first SELECT’s labels.
         a = table("a", column("val"))