]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Propagate compiler kw for visit_values to parameters
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Apr 2021 19:09:51 +0000 (15:09 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Apr 2021 19:28:45 +0000 (15:28 -0400)
Fixed issue in SQL compiler where the bound parameters set up for a
:class:`.Values` construct wouldn't be positionally tracked correctly if
inside of a :class:`_sql.CTE`, affecting database drivers that support
VALUES + ctes and use positional parameters such as SQL Server in
particular as well as asyncpg.   The fix also repairs support for
compiler flags such as ``literal_binds``.

Fixes: #6327
Change-Id: I2d549228691d0bfc10dadd0955b1549d7584db51

doc/build/changelog/unreleased_14/6327.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_values.py

diff --git a/doc/build/changelog/unreleased_14/6327.rst b/doc/build/changelog/unreleased_14/6327.rst
new file mode 100644 (file)
index 0000000..c5544a1
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6327
+
+    Fixed issue in SQL compiler where the bound parameters set up for a
+    :class:`.Values` construct wouldn't be positionally tracked correctly if
+    inside of a :class:`_sql.CTE`, affecting database drivers that support
+    VALUES + ctes and use positional parameters such as SQL Server in
+    particular as well as asyncpg.   The fix also repairs support for
+    compiler flags such as ``literal_binds``.
index 84df0837c619b34b869fa6ab11d66dd0899b0291..4c591a87f2d6b2e19a4fbe36d4ca2ece59a98385 100644 (file)
@@ -2724,13 +2724,13 @@ class SQLCompiler(Compiled):
         return text
 
     def visit_values(self, element, asfrom=False, from_linter=None, **kw):
-
+        kw.setdefault("literal_binds", element.literal_binds)
         v = "VALUES %s" % ", ".join(
             self.process(
                 elements.Tuple(
                     types=element._column_types, *elem
                 ).self_group(),
-                literal_binds=element.literal_binds,
+                **kw
             )
             for chunk in element._data
             for elem in chunk
index 43e8f85316282e6ed514a393453483ba3a538a81..dcd32a6791ae30a6ab5dc9268001336824c386ae 100644 (file)
@@ -85,6 +85,69 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
             'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)',
         )
 
+    def test_values_in_cte_params(self):
+        cte1 = select(
+            Values(
+                column("col1", String),
+                column("col2", Integer),
+                name="temp_table",
+            ).data([("a", 2), ("b", 3)])
+        ).cte("cte1")
+
+        cte2 = select(cte1.c.col1).where(cte1.c.col1 == "q").cte("cte2")
+        stmt = select(cte2.c.col1)
+
+        dialect = default.DefaultDialect()
+        dialect.positional = True
+        dialect.paramstyle = "numeric"
+        self.assert_compile(
+            stmt,
+            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
+            "temp_table.col2 AS col2 FROM (VALUES (:1, :2), (:3, :4)) AS "
+            "temp_table (col1, col2)), "
+            "cte2 AS "
+            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :5) "
+            "SELECT cte2.col1 FROM cte2",
+            checkpositional=("a", 2, "b", 3, "q"),
+            dialect=dialect,
+        )
+
+        self.assert_compile(
+            stmt,
+            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
+            "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) "
+            "AS temp_table (col1, col2)), "
+            "cte2 AS "
+            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = 'q') "
+            "SELECT cte2.col1 FROM cte2",
+            literal_binds=True,
+            dialect=dialect,
+        )
+
+    def test_values_in_cte_literal_binds(self):
+        cte1 = select(
+            Values(
+                column("col1", String),
+                column("col2", Integer),
+                name="temp_table",
+                literal_binds=True,
+            ).data([("a", 2), ("b", 3)])
+        ).cte("cte1")
+
+        cte2 = select(cte1.c.col1).where(cte1.c.col1 == "q").cte("cte2")
+        stmt = select(cte2.c.col1)
+
+        self.assert_compile(
+            stmt,
+            "WITH cte1 AS (SELECT temp_table.col1 AS col1, "
+            "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) "
+            "AS temp_table (col1, col2)), "
+            "cte2 AS "
+            "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :col1_1) "
+            "SELECT cte2.col1 FROM cte2",
+            checkparams={"col1_1": "q"},
+        )
+
     @testing.fixture
     def literal_parameter_fixture(self):
         def go(literal_binds, omit=None):