]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement literal_binds with expanding + bind_expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Dec 2022 15:22:36 +0000 (10:22 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Dec 2022 15:34:05 +0000 (10:34 -0500)
Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType
error in 1.4) when using an expression whose type included
:meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding"
(i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler
parameter.

Fixes: #8989
Change-Id: Ic9fd27b46381b488117295ea5a492d8fc158e39f
(cherry picked from commit 8c6de3c2c43ab372cbbe76464b4c5be3b6457252)

doc/build/changelog/unreleased_14/8989.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_type_expressions.py

diff --git a/doc/build/changelog/unreleased_14/8989.rst b/doc/build/changelog/unreleased_14/8989.rst
new file mode 100644 (file)
index 0000000..4c38fdf
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, types
+    :tickets: 8989
+    :versions: 2.0.0b5
+
+    Fixed bug where SQL compilation would fail (assertion fail in 2.0, NoneType
+    error in 1.4) when using an expression whose type included
+    :meth:`_types.TypeEngine.bind_expression`, in the context of an "expanding"
+    (i.e. "IN") parameter in conjunction with the ``literal_binds`` compiler
+    parameter.
index 8fbf3092aaff5d4efa72fd7ec8820ba95242dce9..cb30c777389f4511d8c7a11bee6734349a6bbb79 100644 (file)
@@ -699,6 +699,8 @@ class SQLCompiler(Compiled):
 
     """
 
+    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
+
     positiontup = None
     """for a compiled construct that uses a positional paramstyle, will be
     a sequence of strings, indicating the names of bound parameters in order.
@@ -1294,7 +1296,7 @@ class SQLCompiler(Compiled):
             return expr
 
         statement = re.sub(
-            r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]",
+            self._post_compile_pattern,
             process_expanding,
             self.string,
         )
@@ -2094,12 +2096,16 @@ class SQLCompiler(Compiled):
         )
 
     def _literal_execute_expanding_parameter_literal_binds(
-        self, parameter, values
+        self, parameter, values, bind_expression_template=None
     ):
 
         typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
 
         if not values:
+            # empty IN expression.  note we don't need to use
+            # bind_expression_template here because there are no
+            # expressions to render.
+
             if typ_dialect_impl._is_tuple_type:
                 replacement_expression = (
                     "VALUES " if self.dialect.tuple_in_values else ""
@@ -2120,6 +2126,12 @@ class SQLCompiler(Compiled):
             )
         ):
 
+            if typ_dialect_impl._has_bind_expression:
+                raise NotImplementedError(
+                    "bind_expression() on TupleType not supported with "
+                    "literal_binds"
+                )
+
             replacement_expression = (
                 "VALUES " if self.dialect.tuple_in_values else ""
             ) + ", ".join(
@@ -2135,10 +2147,29 @@ class SQLCompiler(Compiled):
                 for i, tuple_element in enumerate(values)
             )
         else:
-            replacement_expression = ", ".join(
-                self.render_literal_value(value, parameter.type)
-                for value in values
-            )
+            if bind_expression_template:
+                post_compile_pattern = self._post_compile_pattern
+                m = post_compile_pattern.search(bind_expression_template)
+                assert m and m.group(
+                    2
+                ), "unexpected format for expanding parameter"
+
+                tok = m.group(2).split("~~")
+                be_left, be_right = tok[1], tok[3]
+                replacement_expression = ", ".join(
+                    "%s%s%s"
+                    % (
+                        be_left,
+                        self.render_literal_value(value, parameter.type),
+                        be_right,
+                    )
+                    for value in values
+                )
+            else:
+                replacement_expression = ", ".join(
+                    self.render_literal_value(value, parameter.type)
+                    for value in values
+                )
 
         return (), replacement_expression
 
@@ -2453,7 +2484,7 @@ class SQLCompiler(Compiled):
                     bind_expression,
                     skip_bind_expression=True,
                     within_columns_clause=within_columns_clause,
-                    literal_binds=literal_binds,
+                    literal_binds=literal_binds and not bindparam.expanding,
                     literal_execute=literal_execute,
                     render_postcompile=render_postcompile,
                     **kwargs
@@ -2461,14 +2492,26 @@ class SQLCompiler(Compiled):
                 if bindparam.expanding:
                     # for postcompile w/ expanding, move the "wrapped" part
                     # of this into the inside
+
                     m = re.match(
                         r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                     )
+                    assert m, "unexpected format for expanding parameter"
                     wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                         m.group(2),
                         m.group(1),
                         m.group(3),
                     )
+
+                    if literal_binds:
+                        ret = self.render_literal_bindparam(
+                            bindparam,
+                            within_columns_clause=True,
+                            bind_expression_template=wrapped,
+                            **kwargs
+                        )
+                        return "(%s)" % ret
+
                 return wrapped
 
         if not literal_binds:
@@ -2568,7 +2611,11 @@ class SQLCompiler(Compiled):
         return ret
 
     def render_literal_bindparam(
-        self, bindparam, render_literal_value=NO_ARG, **kw
+        self,
+        bindparam,
+        render_literal_value=NO_ARG,
+        bind_expression_template=None,
+        **kw
     ):
         if render_literal_value is not NO_ARG:
             value = render_literal_value
@@ -2587,7 +2634,11 @@ class SQLCompiler(Compiled):
 
         if bindparam.expanding:
             leep = self._literal_execute_expanding_parameter_literal_binds
-            to_update, replacement_expr = leep(bindparam, value)
+            to_update, replacement_expr = leep(
+                bindparam,
+                value,
+                bind_expression_template=bind_expression_template,
+            )
             return replacement_expr
         else:
             return self.render_literal_value(value, bindparam.type)
index e0e0858a45093185ec7317108bd61646e4a06eb2..7c2192620795dc10f5032be97f495a7917303088 100644 (file)
@@ -182,28 +182,40 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL):
             "test_table WHERE test_table.y = lower(:y_1)",
         )
 
-    def test_in_binds(self):
+    @testing.variation(
+        "compile_opt", ["plain", "postcompile", "literal_binds"]
+    )
+    def test_in_binds(self, compile_opt):
         table = self._fixture()
 
-        self.assert_compile(
-            select(table).where(
-                table.c.y.in_(["hi", "there", "some", "expr"])
-            ),
-            "SELECT test_table.x, lower(test_table.y) AS y FROM "
-            "test_table WHERE test_table.y IN "
-            "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
-            render_postcompile=False,
+        stmt = select(table).where(
+            table.c.y.in_(["hi", "there", "some", "expr"])
         )
 
-        self.assert_compile(
-            select(table).where(
-                table.c.y.in_(["hi", "there", "some", "expr"])
-            ),
-            "SELECT test_table.x, lower(test_table.y) AS y FROM "
-            "test_table WHERE test_table.y IN "
-            "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
-            render_postcompile=True,
-        )
+        if compile_opt.plain:
+            self.assert_compile(
+                stmt,
+                "SELECT test_table.x, lower(test_table.y) AS y FROM "
+                "test_table WHERE test_table.y IN "
+                "(__[POSTCOMPILE_y_1~~lower(~~REPL~~)~~])",
+                render_postcompile=False,
+            )
+        elif compile_opt.postcompile:
+            self.assert_compile(
+                stmt,
+                "SELECT test_table.x, lower(test_table.y) AS y FROM "
+                "test_table WHERE test_table.y IN "
+                "(lower(:y_1_1), lower(:y_1_2), lower(:y_1_3), lower(:y_1_4))",
+                render_postcompile=True,
+            )
+        elif compile_opt.literal_binds:
+            self.assert_compile(
+                stmt,
+                "SELECT test_table.x, lower(test_table.y) AS y FROM "
+                "test_table WHERE test_table.y IN "
+                "(lower('hi'), lower('there'), lower('some'), lower('expr'))",
+                literal_binds=True,
+            )
 
     def test_dialect(self):
         table = self._fixture()