]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Maintain compiled_params / replacement_expressions within expanding IN
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Dec 2018 22:35:12 +0000 (17:35 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 22 Dec 2018 02:56:45 +0000 (21:56 -0500)
Fixed issue in "expanding IN" feature where using the same bound parameter
name more than once in a query would lead to a KeyError within the process
of rewriting the parameters in the query.

Fixes: #4394
Change-Id: Ibcadce9fefbcb060266d9447c2044ee6efeccf5a
(cherry picked from commit c495769751e8b19d54fb92388ced587b5d13b85d)

doc/build/changelog/unreleased_12/4394.rst [new file with mode: 0644]
lib/sqlalchemy/engine/default.py
test/sql/test_query.py

diff --git a/doc/build/changelog/unreleased_12/4394.rst b/doc/build/changelog/unreleased_12/4394.rst
new file mode 100644 (file)
index 0000000..faa3547
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+   :tag: bug, sql
+   :tickets: 4394
+
+   Fixed issue in "expanding IN" feature where using the same bound parameter
+   name more than once in a query would lead to a KeyError within the process
+   of rewriting the parameters in the query.
index 915812a4f24a3bd5aceb09e1c530cb7ba6425caf..e63c5eafca7c6bb4f6d1cbe93d0330f805ff71dd 100644 (file)
@@ -726,45 +726,58 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
             positiontup = None
 
         replacement_expressions = {}
+        to_update_sets = {}
+
         for name in (
             self.compiled.positiontup if compiled.positional
             else self.compiled.binds
         ):
             parameter = self.compiled.binds[name]
             if parameter.expanding:
-                values = compiled_params.pop(name)
-                if not values:
-                    raise exc.InvalidRequestError(
-                        "'expanding' parameters can't be used with an "
-                        "empty list"
-                    )
-                elif isinstance(values[0], (tuple, list)):
-                    to_update = [
-                        ("%s_%s_%s" % (name, i, j), value)
-                        for i, tuple_element in enumerate(values, 1)
-                        for j, value in enumerate(tuple_element, 1)
-                    ]
-                    replacement_expressions[name] = ", ".join(
-                        "(%s)" % ", ".join(
+                if name in replacement_expressions:
+                    to_update = to_update_sets[name]
+                else:
+                    # we are removing the parameter from compiled_params
+                    # because it is a list value, which is not expected by
+                    # TypeEngine objects that would otherwise be asked to
+                    # process it. the single name is being replaced with
+                    # individual numbered parameters for each value in the
+                    # param.
+                    values = compiled_params.pop(name)
+
+                    if not values:
+                        raise exc.InvalidRequestError(
+                            "'expanding' parameters with an empty list not "
+                            "supported until SQLAlchemy 1.3."
+                        )
+                    elif isinstance(values[0], (tuple, list)):
+                        to_update = to_update_sets[name] = [
+                            ("%s_%s_%s" % (name, i, j), value)
+                            for i, tuple_element in enumerate(values, 1)
+                            for j, value in enumerate(tuple_element, 1)
+                        ]
+                        replacement_expressions[name] = ", ".join(
+                            "(%s)" % ", ".join(
+                                self.compiled.bindtemplate % {
+                                    "name":
+                                    to_update[i * len(tuple_element) + j][0]
+                                }
+                                for j, value in enumerate(tuple_element)
+                            )
+                            for i, tuple_element in enumerate(values)
+
+                        )
+                    else:
+                        to_update = to_update_sets[name] = [
+                            ("%s_%s" % (name, i), value)
+                            for i, value in enumerate(values, 1)
+                        ]
+                        replacement_expressions[name] = ", ".join(
                             self.compiled.bindtemplate % {
-                                "name":
-                                to_update[i * len(tuple_element) + j][0]
-                            }
-                            for j, value in enumerate(tuple_element)
+                                "name": key}
+                            for key, value in to_update
                         )
-                        for i, tuple_element in enumerate(values)
 
-                    )
-                else:
-                    to_update = [
-                        ("%s_%s" % (name, i), value)
-                        for i, value in enumerate(values, 1)
-                    ]
-                    replacement_expressions[name] = ", ".join(
-                        self.compiled.bindtemplate % {
-                            "name": key}
-                        for key, value in to_update
-                    )
                 compiled_params.update(to_update)
                 processors.update(
                     (key, processors[name])
@@ -778,7 +791,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                 positiontup.append(name)
 
         def process_expanding(m):
-            return replacement_expressions.pop(m.group(1))
+            return replacement_expressions[m.group(1)]
 
         self.statement = re.sub(
             r"\[EXPANDING_(\S+)\]",
index 3e629fb261f2353e408a8ed5b7b1dc0f0d45de2c..d649da202ad53d10c9b1c6ef600f652b20552edc 100644 (file)
@@ -457,7 +457,7 @@ class QueryTest(fixtures.TestBase):
 
             assert_raises_message(
                 exc.StatementError,
-                "'expanding' parameters can't be used with an empty list",
+                "'expanding' parameters with an empty list not supported",
                 conn.execute,
                 stmt, {"uname": []}
             )
@@ -531,6 +531,41 @@ class QueryTest(fixtures.TestBase):
                 [(8, 'fred'), (9, 'ed')]
             )
 
+    def test_expanding_in_repeated(self):
+        testing.db.execute(
+            users.insert(),
+            [
+                dict(user_id=7, user_name='jack'),
+                dict(user_id=8, user_name='fred'),
+                dict(user_id=9, user_name='ed')
+            ]
+        )
+
+        with testing.db.connect() as conn:
+            stmt = select([users]).where(
+                users.c.user_name.in_(
+                    bindparam('uname', expanding=True)
+                ) | users.c.user_name.in_(bindparam('uname2', expanding=True))
+            ).where(users.c.user_id == 8)
+            stmt = stmt.union(
+                select([users]).where(
+                    users.c.user_name.in_(
+                        bindparam('uname', expanding=True)
+                    ) | users.c.user_name.in_(
+                        bindparam('uname2', expanding=True))
+                ).where(users.c.user_id == 9)
+            ).order_by(stmt.c.user_id)
+
+            eq_(
+                conn.execute(
+                    stmt,
+                    {
+                        "uname": ['jack', 'fred'],
+                        "uname2": ['ed'], "userid": [8, 9]}
+                ).fetchall(),
+                [(8, 'fred'), (9, 'ed')]
+            )
+
     @testing.requires.tuple_in
     def test_expanding_in_composite(self):
         testing.db.execute(