From: Mike Bayer Date: Thu, 15 Dec 2022 15:22:36 +0000 (-0500) Subject: implement literal_binds with expanding + bind_expression X-Git-Tag: rel_1_4_46~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=84ba8874e146bcdbf46ce70ece32c4c224c3fd44;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement literal_binds with expanding + bind_expression Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType error in 1.4) when using an expression whose type included :meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding" (i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler parameter. Fixes: #8989 Change-Id: Ic9fd27b46381b488117295ea5a492d8fc158e39f (cherry picked from commit 8c6de3c2c43ab372cbbe76464b4c5be3b6457252) --- diff --git a/doc/build/changelog/unreleased_14/8989.rst b/doc/build/changelog/unreleased_14/8989.rst new file mode 100644 index 0000000000..4c38fdf019 --- /dev/null +++ b/doc/build/changelog/unreleased_14/8989.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, types + :tickets: 8989 + :versions: 2.0.0b5 + + Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType + error in 1.4) when using an expression whose type included + :meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding" + (i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler + parameter. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8fbf3092aa..cb30c77738 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -699,6 +699,8 @@ class SQLCompiler(Compiled): """ + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + positiontup = None """for a compiled construct that uses a positional paramstyle, will be a sequence of strings, indicating the names of bound parameters in order. @@ -1294,7 +1296,7 @@ class SQLCompiler(Compiled): return expr statement = re.sub( - r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]", + self._post_compile_pattern, process_expanding, self.string, ) @@ -2094,12 +2096,16 @@ class SQLCompiler(Compiled): ) def _literal_execute_expanding_parameter_literal_binds( - self, parameter, values + self, parameter, values, bind_expression_template=None ): typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) if not values: + # empty IN expression. note we don't need to use + # bind_expression_template here because there are no + # expressions to render. + if typ_dialect_impl._is_tuple_type: replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" @@ -2120,6 +2126,12 @@ class SQLCompiler(Compiled): ) ): + if typ_dialect_impl._has_bind_expression: + raise NotImplementedError( + "bind_expression() on TupleType not supported with " + "literal_binds" + ) + replacement_expression = ( "VALUES " if self.dialect.tuple_in_values else "" ) + ", ".join( @@ -2135,10 +2147,29 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values) ) else: - replacement_expression = ", ".join( - self.render_literal_value(value, parameter.type) - for value in values - ) + if bind_expression_template: + post_compile_pattern = self._post_compile_pattern + m = post_compile_pattern.search(bind_expression_template) + assert m and m.group( + 2 + ), "unexpected format for expanding parameter" + + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + replacement_expression = ", ".join( + "%s%s%s" + % ( + be_left, + self.render_literal_value(value, parameter.type), + be_right, + ) + for value in values + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) return (), replacement_expression @@ -2453,7 +2484,7 @@ class SQLCompiler(Compiled): bind_expression, skip_bind_expression=True, within_columns_clause=within_columns_clause, - literal_binds=literal_binds, + literal_binds=literal_binds and not bindparam.expanding, literal_execute=literal_execute, render_postcompile=render_postcompile, **kwargs @@ -2461,14 +2492,26 @@ class SQLCompiler(Compiled): if bindparam.expanding: # for postcompile w/ expanding, move the "wrapped" part # of this into the inside + m = re.match( r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped ) + assert m, "unexpected format for expanding parameter" wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( m.group(2), m.group(1), m.group(3), ) + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, + within_columns_clause=True, + bind_expression_template=wrapped, + **kwargs + ) + return "(%s)" % ret + return wrapped if not literal_binds: @@ -2568,7 +2611,11 @@ class SQLCompiler(Compiled): return ret def render_literal_bindparam( - self, bindparam, render_literal_value=NO_ARG, **kw + self, + bindparam, + render_literal_value=NO_ARG, + bind_expression_template=None, + **kw ): if render_literal_value is not NO_ARG: value = render_literal_value @@ -2587,7 +2634,11 @@ class SQLCompiler(Compiled): if bindparam.expanding: leep = self._literal_execute_expanding_parameter_literal_binds - to_update, replacement_expr = leep(bindparam, value) + to_update, replacement_expr = leep( + bindparam, + value, + bind_expression_template=bind_expression_template, + ) return replacement_expr else: return self.render_literal_value(value, bindparam.type) diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index e0e0858a45..7c21926207 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -182,28 +182,40 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL): "test_table WHERE test_table.y = lower(:y_1)", ) - def test_in_binds(self): + @testing.variation( + "compile_opt", ["plain", "postcompile", "literal_binds"] + ) + def test_in_binds(self, compile_opt): table = self._fixture() - self.assert_compile( - select(table).where( - table.c.y.in_(["hi", "there", "some", "expr"]) - ), - "SELECT test_table.x, lower(test_table.y) AS y FROM " - "test_table WHERE test_table.y IN " - "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])", - render_postcompile=False, + stmt = select(table).where( + table.c.y.in_(["hi", "there", "some", "expr"]) ) - self.assert_compile( - select(table).where( - table.c.y.in_(["hi", "there", "some", "expr"]) - ), - "SELECT test_table.x, lower(test_table.y) AS y FROM " - "test_table WHERE test_table.y IN " - "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))", - render_postcompile=True, - ) + if compile_opt.plain: + self.assert_compile( + stmt, + "SELECT test_table.x, lower(test_table.y) AS y FROM " + "test_table WHERE test_table.y IN " + "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])", + render_postcompile=False, + ) + elif compile_opt.postcompile: + self.assert_compile( + stmt, + "SELECT test_table.x, lower(test_table.y) AS y FROM " + "test_table WHERE test_table.y IN " + "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))", + render_postcompile=True, + ) + elif compile_opt.literal_binds: + self.assert_compile( + stmt, + "SELECT test_table.x, lower(test_table.y) AS y FROM " + "test_table WHERE test_table.y IN " + "(lower('hi'), lower('there'), lower('some'), lower('expr'))", + literal_binds=True, + ) def test_dialect(self): table = self._fixture()