]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Maintain compiled_params / replacement_expressions within expanding IN
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2018 22:35:12 +0000 (17:35 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Dec 2018 02:50:55 +0000 (21:50 -0500)
Fixed issue in "expanding IN" feature where using the same bound parameter
name more than once in a query would lead to a KeyError within the process
of rewriting the parameters in the query.

Fixes: #4394
Change-Id: Ibcadce9fefbcb060266d9447c2044ee6efeccf5a

doc/build/changelog/unreleased_12/4394.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
test/sql/test_query.py

diff --git a/doc/build/changelog/unreleased_12/4394.rst b/doc/build/changelog/unreleased_12/4394.rst
new file mode 100644 (file)
index 0000000..faa3547
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+   :tag: bug, sql
+   :tickets: 4394
+
+   Fixed issue in "expanding IN" feature where using the same bound parameter
+   name more than once in a query would lead to a KeyError within the process
+   of rewriting the parameters in the query.
index 5c96e4240e1ead003db6bf7d344449598e8c3720..028abc4c24b10d6cba3296425be0b6b4f8d82ace 100644 (file)
@@ -726,50 +726,64 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             positiontup = None
 
         replacement_expressions = {}
+        to_update_sets = {}
+
         for name in (
             self.compiled.positiontup if compiled.positional
             else self.compiled.binds
         ):
             parameter = self.compiled.binds[name]
             if parameter.expanding:
-                values = compiled_params.pop(name)
-                if not values:
-                    to_update = []
-                    replacement_expressions[name] = (
-                        self.compiled.visit_empty_set_expr(
-                            parameter._expanding_in_types
-                            if parameter._expanding_in_types
-                            else [parameter.type]
+
+                if name in replacement_expressions:
+                    to_update = to_update_sets[name]
+                else:
+                    # we are removing the parameter from compiled_params
+                    # because it is a list value, which is not expected by
+                    # TypeEngine objects that would otherwise be asked to
+                    # process it. the single name is being replaced with
+                    # individual numbered parameters for each value in the
+                    # param.
+                    values = compiled_params.pop(name)
+
+                    if not values:
+                        to_update = to_update_sets[name] = []
+                        replacement_expressions[name] = (
+                            self.compiled.visit_empty_set_expr(
+                                parameter._expanding_in_types
+                                if parameter._expanding_in_types
+                                else [parameter.type]
+                            )
                         )
-                    )
 
-                elif isinstance(values[0], (tuple, list)):
-                    to_update = [
-                        ("%s_%s_%s" % (name, i, j), value)
-                        for i, tuple_element in enumerate(values, 1)
-                        for j, value in enumerate(tuple_element, 1)
-                    ]
-                    replacement_expressions[name] = ", ".join(
-                        "(%s)" % ", ".join(
+                    elif isinstance(values[0], (tuple, list)):
+                        to_update = to_update_sets[name] = [
+                            ("%s_%s_%s" % (name, i, j), value)
+                            for i, tuple_element in enumerate(values, 1)
+                            for j, value in enumerate(tuple_element, 1)
+                        ]
+                        replacement_expressions[name] = ", ".join(
+                            "(%s)" % ", ".join(
+                                self.compiled.bindtemplate % {
+                                    "name":
+                                    to_update[i * len(tuple_element) + j][0]
+                                }
+                                for j, value in enumerate(tuple_element)
+                            )
+                            for i, tuple_element in enumerate(values)
+
+                        )
+                    else:
+                        to_update = to_update_sets[name] = [
+                            ("%s_%s" % (name, i), value)
+                            for i, value in enumerate(values, 1)
+                        ]
+                        replacement_expressions[name] = ", ".join(
                             self.compiled.bindtemplate % {
-                                "name":
-                                to_update[i * len(tuple_element) + j][0]
-                            }
-                            for j, value in enumerate(tuple_element)
+                                "name": key}
+                            for key, value in to_update
                         )
-                        for i, tuple_element in enumerate(values)
 
-                    )
-                else:
-                    to_update = [
-                        ("%s_%s" % (name, i), value)
-                        for i, value in enumerate(values, 1)
-                    ]
-                    replacement_expressions[name] = ", ".join(
-                        self.compiled.bindtemplate % {
-                            "name": key}
-                        for key, value in to_update
-                    )
                 compiled_params.update(to_update)
                 processors.update(
                     (key, processors[name])
@@ -783,7 +797,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 positiontup.append(name)
 
         def process_expanding(m):
-            return replacement_expressions.pop(m.group(1))
+            return replacement_expressions[m.group(1)]
 
         self.statement = re.sub(
             r"\[EXPANDING_(\S+)\]",
index 971374eb92ea5d0540fc004ade9687f4b3d7e5ac..175b69c4f216c366a5d8f379d79faeb7afb0449d 100644 (file)
@@ -529,6 +529,41 @@ class QueryTest(fixtures.TestBase):
                 [(8, 'fred'), (9, 'ed')]
             )
 
+    def test_expanding_in_repeated(self):
+        testing.db.execute(
+            users.insert(),
+            [
+                dict(user_id=7, user_name='jack'),
+                dict(user_id=8, user_name='fred'),
+                dict(user_id=9, user_name='ed')
+            ]
+        )
+
+        with testing.db.connect() as conn:
+            stmt = select([users]).where(
+                users.c.user_name.in_(
+                    bindparam('uname', expanding=True)
+                ) | users.c.user_name.in_(bindparam('uname2', expanding=True))
+            ).where(users.c.user_id == 8)
+            stmt = stmt.union(
+                select([users]).where(
+                    users.c.user_name.in_(
+                        bindparam('uname', expanding=True)
+                    ) | users.c.user_name.in_(
+                        bindparam('uname2', expanding=True))
+                ).where(users.c.user_id == 9)
+            ).order_by(stmt.c.user_id)
+
+            eq_(
+                conn.execute(
+                    stmt,
+                    {
+                        "uname": ['jack', 'fred'],
+                        "uname2": ['ed'], "userid": [8, 9]}
+                ).fetchall(),
+                [(8, 'fred'), (9, 'ed')]
+            )
+
     @testing.requires.tuple_in
     def test_expanding_in_composite(self):
         testing.db.execute(