From: Mike Bayer Date: Tue, 20 Apr 2021 19:09:51 +0000 (-0400) Subject: Propagate compiler kw for visit_values to parameters X-Git-Tag: rel_1_4_10~2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=3a0a6e1db43a4aefd3570f2956ae2567e3062a77;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Propagate compiler kw for visit_values to parameters Fixed issue in SQL compiler where the bound parameters set up for a :class:`.Values` construct wouldn't be positionally tracked correctly if inside of a :class:`_sql.CTE`, affecting database drivers that support VALUES + ctes and use positional parameters such as SQL Server in particular as well as asyncpg. The fix also repairs support for compiler flags such as ``literal_binds``. Fixes: #6327 Change-Id: I2d549228691d0bfc10dadd0955b1549d7584db51 --- diff --git a/doc/build/changelog/unreleased_14/6327.rst b/doc/build/changelog/unreleased_14/6327.rst new file mode 100644 index 0000000000..c5544a1ee6 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6327.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 6327 + + Fixed issue in SQL compiler where the bound parameters set up for a + :class:`.Values` construct wouldn't be positionally tracked correctly if + inside of a :class:`_sql.CTE`, affecting database drivers that support + VALUES + ctes and use positional parameters such as SQL Server in + particular as well as asyncpg. The fix also repairs support for + compiler flags such as ``literal_binds``. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 84df0837c6..4c591a87f2 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2724,13 +2724,13 @@ class SQLCompiler(Compiled): return text def visit_values(self, element, asfrom=False, from_linter=None, **kw): - + kw.setdefault("literal_binds", element.literal_binds) v = "VALUES %s" % ", ".join( self.process( elements.Tuple( types=element._column_types, *elem ).self_group(), - literal_binds=element.literal_binds, + **kw ) for chunk in element._data for elem in chunk diff --git a/test/sql/test_values.py b/test/sql/test_values.py index 43e8f85316..dcd32a6791 100644 --- a/test/sql/test_values.py +++ b/test/sql/test_values.py @@ -85,6 +85,69 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): 'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)', ) + def test_values_in_cte_params(self): + cte1 = select( + Values( + column("col1", String), + column("col2", Integer), + name="temp_table", + ).data([("a", 2), ("b", 3)]) + ).cte("cte1") + + cte2 = select(cte1.c.col1).where(cte1.c.col1 == "q").cte("cte2") + stmt = select(cte2.c.col1) + + dialect = default.DefaultDialect() + dialect.positional = True + dialect.paramstyle = "numeric" + self.assert_compile( + stmt, + "WITH cte1 AS (SELECT temp_table.col1 AS col1, " + "temp_table.col2 AS col2 FROM (VALUES (:1, :2), (:3, :4)) AS " + "temp_table (col1, col2)), " + "cte2 AS " + "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :5) " + "SELECT cte2.col1 FROM cte2", + checkpositional=("a", 2, "b", 3, "q"), + dialect=dialect, + ) + + self.assert_compile( + stmt, + "WITH cte1 AS (SELECT temp_table.col1 AS col1, " + "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) " + "AS temp_table (col1, col2)), " + "cte2 AS " + "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = 'q') " + "SELECT cte2.col1 FROM cte2", + literal_binds=True, + dialect=dialect, + ) + + def test_values_in_cte_literal_binds(self): + cte1 = select( + Values( + column("col1", String), + column("col2", Integer), + name="temp_table", + literal_binds=True, + ).data([("a", 2), ("b", 3)]) + ).cte("cte1") + + cte2 = select(cte1.c.col1).where(cte1.c.col1 == "q").cte("cte2") + stmt = select(cte2.c.col1) + + self.assert_compile( + stmt, + "WITH cte1 AS (SELECT temp_table.col1 AS col1, " + "temp_table.col2 AS col2 FROM (VALUES ('a', 2), ('b', 3)) " + "AS temp_table (col1, col2)), " + "cte2 AS " + "(SELECT cte1.col1 AS col1 FROM cte1 WHERE cte1.col1 = :col1_1) " + "SELECT cte2.col1 FROM cte2", + checkparams={"col1_1": "q"}, + ) + @testing.fixture def literal_parameter_fixture(self): def go(literal_binds, omit=None):