]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix positional compiling bugs
authorFederico Caselli <cfederico87@gmail.com>
Sat, 19 Nov 2022 19:39:10 +0000 (20:39 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 1 Dec 2022 22:54:24 +0000 (23:54 +0100)
Fixed a series of issues regarding positionally rendered bound parameters,
such as those used for SQLite, asyncpg, MySQL and others. Some compiled
forms would not maintain the order of parameters correctly, such as the
PostgreSQL ``regexp_replace()`` function as well as within the "nesting"
feature of the :class:`.CTE` construct first introduced in :ticket:`4123`.

Fixes: #8827
Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb
(cherry picked from commit 0f2baae6bf72353f785bad394684f2d6fa53e0ef)

doc/build/changelog/unreleased_14/8827.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/assertions.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/sql/test_compiler.py
test/sql/test_cte.py
test/sql/test_external_traversal.py
test/sql/test_functions.py

diff --git a/doc/build/changelog/unreleased_14/8827.rst b/doc/build/changelog/unreleased_14/8827.rst
new file mode 100644 (file)
index 0000000..677277e
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 8827
+
+    Fixed a series of issues regarding positionally rendered bound parameters,
+    such as those used for SQLite, asyncpg, MySQL and others. Some compiled
+    forms would not maintain the order of parameters correctly, such as the
+    PostgreSQL ``regexp_replace()`` function as well as within the "nesting"
+    feature of the :class:`.CTE` construct first introduced in :ticket:`4123`.
index 77f0dbd2df6d20f069b7fcdc880808096349ccaf..417ab84b7b732ba4a8a2e0802f8e37ecc3b6fd9c 100644 (file)
@@ -941,7 +941,7 @@ class OracleCompiler(compiler.SQLCompiler):
     def visit_function(self, func, **kw):
         text = super(OracleCompiler, self).visit_function(func, **kw)
         if kw.get("asfrom", False):
-            text = "TABLE (%s)" % func
+            text = "TABLE (%s)" % text
         return text
 
     def visit_table_valued_column(self, element, **kw):
@@ -1270,20 +1270,18 @@ class OracleCompiler(compiler.SQLCompiler):
             self.process(binary.right),
         )
 
-    def _get_regexp_args(self, binary, kw):
+    def visit_regexp_match_op_binary(self, binary, operator, **kw):
         string = self.process(binary.left, **kw)
         pattern = self.process(binary.right, **kw)
         flags = binary.modifiers["flags"]
-        if flags is not None:
-            flags = self.process(flags, **kw)
-        return string, pattern, flags
-
-    def visit_regexp_match_op_binary(self, binary, operator, **kw):
-        string, pattern, flags = self._get_regexp_args(binary, kw)
         if flags is None:
             return "REGEXP_LIKE(%s, %s)" % (string, pattern)
         else:
-            return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags)
+            return "REGEXP_LIKE(%s, %s, %s)" % (
+                string,
+                pattern,
+                self.process(flags, **kw),
+            )
 
     def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
         return "NOT %s" % self.visit_regexp_match_op_binary(
@@ -1291,8 +1289,10 @@ class OracleCompiler(compiler.SQLCompiler):
         )
 
     def visit_regexp_replace_op_binary(self, binary, operator, **kw):
-        string, pattern, flags = self._get_regexp_args(binary, kw)
+        string = self.process(binary.left, **kw)
+        pattern = self.process(binary.right, **kw)
         replacement = self.process(binary.modifiers["replacement"], **kw)
+        flags = binary.modifiers["flags"]
         if flags is None:
             return "REGEXP_REPLACE(%s, %s, %s)" % (
                 string,
@@ -1304,7 +1304,7 @@ class OracleCompiler(compiler.SQLCompiler):
                 string,
                 pattern,
                 replacement,
-                flags,
+                self.process(flags, **kw),
             )
 
 
index c390553353a68ecb5bffa7c8b7aa17b1f81e56ea..9ad8379e26b7ce97a40b84e6a8184a8310cf4e55 100644 (file)
@@ -2386,14 +2386,11 @@ class PGCompiler(compiler.SQLCompiler):
             return self._generate_generic_binary(
                 binary, " %s* " % base_op, **kw
             )
-        flags = self.process(flags, **kw)
-        string = self.process(binary.left, **kw)
-        pattern = self.process(binary.right, **kw)
         return "%s %s CONCAT('(?', %s, ')', %s)" % (
-            string,
+            self.process(binary.left, **kw),
             base_op,
-            flags,
-            pattern,
+            self.process(flags, **kw),
+            self.process(binary.right, **kw),
         )
 
     def visit_regexp_match_op_binary(self, binary, operator, **kw):
@@ -2406,8 +2403,6 @@ class PGCompiler(compiler.SQLCompiler):
         string = self.process(binary.left, **kw)
         pattern = self.process(binary.right, **kw)
         flags = binary.modifiers["flags"]
-        if flags is not None:
-            flags = self.process(flags, **kw)
         replacement = self.process(binary.modifiers["replacement"], **kw)
         if flags is None:
             return "REGEXP_REPLACE(%s, %s, %s)" % (
@@ -2420,7 +2415,7 @@ class PGCompiler(compiler.SQLCompiler):
                 string,
                 pattern,
                 replacement,
-                flags,
+                self.process(flags, **kw),
             )
 
     def visit_empty_set_expr(self, element_types):
index 611cd182187bacfb7f850dce242804830adf8a65..8fbf3092aaff5d4efa72fd7ec8820ba95242dce9 100644 (file)
@@ -166,8 +166,8 @@ BIND_TEMPLATES = {
     "named": ":%(name)s",
 }
 
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
 
 OPERATORS = {
     # binary
@@ -713,6 +713,7 @@ class SQLCompiler(Compiled):
         debugging use cases.
 
     """
+    positiontup_level = None
 
     inline = False
 
@@ -784,6 +785,7 @@ class SQLCompiler(Compiled):
         # true if the paramstyle is positional
         self.positional = dialect.positional
         if self.positional:
+            self.positiontup_level = {}
             self.positiontup = []
             self._numeric_binds = dialect.paramstyle == "numeric"
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
@@ -894,6 +896,8 @@ class SQLCompiler(Compiled):
         self.ctes_recursive = False
         if self.positional:
             self.cte_positional = {}
+            self.cte_level = {}
+            self.cte_order = collections.defaultdict(list)
 
     @contextlib.contextmanager
     def _nested_result(self):
@@ -1696,7 +1700,13 @@ class SQLCompiler(Compiled):
         text = self.process(taf.element, **kw)
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level,
+                    visiting_cte=kw.get("visiting_cte"),
+                )
+                + text
+            )
 
         self.stack.pop(-1)
 
@@ -1806,6 +1816,7 @@ class SQLCompiler(Compiled):
         )
 
     def visit_over(self, over, **kwargs):
+        text = over.element._compiler_dispatch(self, **kwargs)
         if over.range_:
             range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
                 over.range_, **kwargs
@@ -1818,7 +1829,7 @@ class SQLCompiler(Compiled):
             range_ = None
 
         return "%s OVER (%s)" % (
-            over.element._compiler_dispatch(self, **kwargs),
+            text,
             " ".join(
                 [
                     "%s BY %s"
@@ -1964,7 +1975,9 @@ class SQLCompiler(Compiled):
             nesting_level = len(self.stack) if not toplevel else None
             text = (
                 self._render_cte_clause(
-                    nesting_level=nesting_level, include_following_stack=True
+                    nesting_level=nesting_level,
+                    include_following_stack=True,
+                    visiting_cte=kwargs.get("visiting_cte"),
                 )
                 + text
             )
@@ -2667,7 +2680,8 @@ class SQLCompiler(Compiled):
                 positional_names.append(name)
             else:
                 self.positiontup.append(name)
-        elif not escaped_from:
+            self.positiontup_level[name] = len(self.stack)
+        if not escaped_from:
 
             if _BIND_TRANSLATE_RE.search(name):
                 # not quite the translate use case as we want to
@@ -2786,6 +2800,8 @@ class SQLCompiler(Compiled):
                         ]
                     }
                 )
+            if self.positional:
+                self.cte_level[cte] = cte_level
 
             if pre_alias_cte not in self.ctes:
                 self.visit_cte(pre_alias_cte, **kwargs)
@@ -3495,13 +3511,16 @@ class SQLCompiler(Compiled):
             if per_dialect:
                 text += " " + self.get_statement_hint_text(per_dialect)
 
-        if self.ctes:
-            # In compound query, CTEs are shared at the compound level
-            if not is_embedded_select:
-                nesting_level = len(self.stack) if not toplevel else None
-                text = (
-                    self._render_cte_clause(nesting_level=nesting_level) + text
+        # In compound query, CTEs are shared at the compound level
+        if self.ctes and (not is_embedded_select or toplevel):
+            nesting_level = len(self.stack) if not toplevel else None
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level,
+                    visiting_cte=kwargs.get("visiting_cte"),
                 )
+                + text
+            )
 
         if select_stmt._suffixes:
             text += " " + self._generate_prefixes(
@@ -3677,6 +3696,7 @@ class SQLCompiler(Compiled):
         self,
         nesting_level=None,
         include_following_stack=False,
+        visiting_cte=None,
     ):
         """
         include_following_stack
@@ -3706,14 +3726,48 @@ class SQLCompiler(Compiled):
 
         if not ctes:
             return ""
-
         ctes_recursive = any([cte.recursive for cte in ctes])
 
         if self.positional:
-            self.positiontup = (
-                sum([self.cte_positional[cte] for cte in ctes], [])
-                + self.positiontup
-            )
+            self.cte_order[visiting_cte].extend(ctes)
+
+            if visiting_cte is None and self.cte_order:
+                assert self.positiontup is not None
+
+                def get_nested_positional(cte):
+                    if cte in self.cte_order:
+                        children = self.cte_order.pop(cte)
+                        to_add = list(
+                            itertools.chain.from_iterable(
+                                get_nested_positional(child_cte)
+                                for child_cte in children
+                            )
+                        )
+                        if cte in self.cte_positional:
+                            return reorder_positional(
+                                self.cte_positional[cte],
+                                to_add,
+                                self.cte_level[children[0]],
+                            )
+                        else:
+                            return to_add
+                    else:
+                        return self.cte_positional.get(cte, [])
+
+                def reorder_positional(pos, to_add, level):
+                    if not level:
+                        return to_add + pos
+                    index = 0
+                    for index, name in enumerate(reversed(pos)):
+                        if self.positiontup_level[name] < level:  # type: ignore[index] # noqa: E501
+                            break
+                    return pos[:-index] + to_add + pos[-index:]
+
+                to_add = get_nested_positional(None)
+                self.positiontup = reorder_positional(
+                    self.positiontup, to_add, nesting_level
+                )
+
         cte_text = self.get_cte_preamble(ctes_recursive) + " "
         cte_text += ", \n".join([txt for txt in ctes.values()])
         cte_text += "\n "
@@ -3985,6 +4039,7 @@ class SQLCompiler(Compiled):
                     self._render_cte_clause(
                         nesting_level=nesting_level,
                         include_following_stack=True,
+                        visiting_cte=kw.get("visiting_cte"),
                     ),
                     select_text,
                 )
@@ -4022,7 +4077,9 @@ class SQLCompiler(Compiled):
             nesting_level = len(self.stack) if not toplevel else None
             text = (
                 self._render_cte_clause(
-                    nesting_level=nesting_level, include_following_stack=True
+                    nesting_level=nesting_level,
+                    include_following_stack=True,
+                    visiting_cte=kw.get("visiting_cte"),
                 )
                 + text
             )
@@ -4162,7 +4219,13 @@ class SQLCompiler(Compiled):
 
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level,
+                    visiting_cte=kw.get("visiting_cte"),
+                )
+                + text
+            )
 
         self.stack.pop(-1)
 
@@ -4268,7 +4331,13 @@ class SQLCompiler(Compiled):
 
         if self.ctes:
             nesting_level = len(self.stack) if not toplevel else None
-            text = self._render_cte_clause(nesting_level=nesting_level) + text
+            text = (
+                self._render_cte_clause(
+                    nesting_level=nesting_level,
+                    visiting_cte=kw.get("visiting_cte"),
+                )
+                + text
+            )
 
         self.stack.pop(-1)
 
index ba6ee14c3b58ad41dfe68fd33714a0e65f0e0e81..9a022265eb10c8658a72f84ab6825fa94ea8cfcd 100644 (file)
@@ -7,7 +7,9 @@
 
 from __future__ import absolute_import
 
+from collections import defaultdict
 import contextlib
+from copy import copy
 import re
 import sys
 import warnings
@@ -499,6 +501,7 @@ class AssertsCompiledSQL(object):
         render_schema_translate=False,
         default_schema_name=None,
         from_linting=False,
+        check_param_order=True,
     ):
         if use_default_dialect:
             dialect = default.DefaultDialect()
@@ -512,8 +515,11 @@ class AssertsCompiledSQL(object):
 
             if dialect is None:
                 dialect = config.db.dialect
-            elif dialect == "default":
-                dialect = default.DefaultDialect()
+            elif dialect == "default" or dialect == "default_qmark":
+                if dialect == "default":
+                    dialect = default.DefaultDialect()
+                else:
+                    dialect = default.DefaultDialect(paramstyle="qmark")
                 dialect.supports_default_values = supports_default_values
                 dialect.supports_default_metavalue = supports_default_metavalue
             elif dialect == "default_enhanced":
@@ -645,7 +651,7 @@ class AssertsCompiledSQL(object):
         if checkparams is not None:
             eq_(c.construct_params(params), checkparams)
         if checkpositional is not None:
-            p = c.construct_params(params)
+            p = c.construct_params(params, escape_names=False)
             eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
         if check_prefetch is not None:
             eq_(c.prefetch, check_prefetch)
@@ -665,6 +671,58 @@ class AssertsCompiledSQL(object):
                 },
                 check_post_param,
             )
+        if check_param_order and getattr(c, "params", None):
+
+            def get_dialect(paramstyle, positional):
+                cp = copy(dialect)
+                cp.paramstyle = paramstyle
+                cp.positional = positional
+                return cp
+
+            pyformat_dialect = get_dialect("pyformat", False)
+            pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+            stmt = re.sub(r"[\n\t]", "", pyformat_c.string)
+
+            qmark_dialect = get_dialect("qmark", True)
+            qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+            values = list(qmark_c.positiontup)
+            escaped = qmark_c.escaped_bind_names
+
+            for post_param in (
+                qmark_c.post_compile_params | qmark_c.literal_execute_params
+            ):
+                name = qmark_c.bind_names[post_param]
+                if name in values:
+                    values = [v for v in values if v != name]
+            positions = []
+            pos_by_value = defaultdict(list)
+            for v in values:
+                try:
+                    if v in pos_by_value:
+                        start = pos_by_value[v][-1]
+                    else:
+                        start = 0
+                    esc = escaped.get(v, v)
+                    pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+                    positions.append(pos)
+                    pos_by_value[v].append(pos)
+                except ValueError:
+                    msg = "Expected to find bindparam %r in %r" % (v, stmt)
+                    assert False, msg
+
+            ordered = all(
+                positions[i - 1] < positions[i]
+                for i in range(1, len(positions))
+            )
+
+            expected = [v for _, v in sorted(zip(positions, values))]
+
+            msg = (
+                "Order of parameters %s does not match the order "
+                "in the statement %s. Statement %r" % (values, expected, stmt)
+            )
+
+            is_true(ordered, msg)
 
 
 class ComparesTables(object):
index 8a8f51df0120dfbc9e9286319eb0264e30be506f..2c586990813666ae9281922e62b1625aa3abd36f 100644 (file)
@@ -1554,11 +1554,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             self.table.c.myid.regexp_replace(
                 "pattern", "replacement", flags="ig"
             ),
-            "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_3, :myid_2)",
+            "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_2, :myid_3)",
             checkparams={
                 "myid_1": "pattern",
-                "myid_3": "replacement",
-                "myid_2": "ig",
+                "myid_2": "replacement",
+                "myid_3": "ig",
             },
         )
 
index 0249c7952ce1aa91c3eaab3d43c995d3f199bd6b..e9de407c8e71e958612700e259ea2871f633bd62 100644 (file)
@@ -3277,11 +3277,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             self.table.c.myid.regexp_replace(
                 "pattern", "replacement", flags="ig"
             ),
-            "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_3)s, %(myid_2)s)",
+            "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_2)s, %(myid_3)s)",
             checkparams={
                 "myid_1": "pattern",
-                "myid_3": "replacement",
-                "myid_2": "ig",
+                "myid_2": "replacement",
+                "myid_3": "ig",
             },
         )
 
index 831ef1887203bcfaa37d3bc93b193ae3d74d2bc2..9ede4af9237b44e55028cc2177b53ffae2385129 100644 (file)
@@ -4794,6 +4794,118 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase):
                 stmt, expected, literal_binds=True, params=params
             )
 
+    standalone_escape = testing.combinations(
+        ("normalname", "normalname"),
+        ("_name", "_name"),
+        ("[BracketsAndCase]", "_BracketsAndCase_"),
+        ("has spaces", "has_spaces"),
+        argnames="paramname, expected",
+    )
+
+    @standalone_escape
+    @testing.variation("use_positional", [True, False])
+    def test_standalone_bindparam_escape(
+        self, paramname, expected, use_positional
+    ):
+        stmt = select(table1.c.myid).where(
+            table1.c.name == bindparam(paramname, value="x")
+        )
+        if use_positional:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable WHERE mytable.name = ?",
+                params={paramname: "y"},
+                checkpositional=("y",),
+                dialect="sqlite",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                "SELECT mytable.myid FROM mytable WHERE mytable.name = :%s"
+                % (expected,),
+                params={paramname: "y"},
+                checkparams={expected: "y"},
+                dialect="default",
+            )
+
+    @standalone_escape
+    @testing.variation("use_assert_compile", [True, False])
+    @testing.variation("use_positional", [True, False])
+    def test_standalone_bindparam_escape_expanding(
+        self, paramname, expected, use_assert_compile, use_positional
+    ):
+        stmt = select(table1.c.myid).where(
+            table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
+        )
+        if use_assert_compile:
+            if use_positional:
+                self.assert_compile(
+                    stmt,
+                    "SELECT mytable.myid FROM mytable "
+                    "WHERE mytable.name IN (?, ?)",
+                    params={paramname: ["y", "z"]},
+                    # NOTE: this is what render_postcompile will do right now
+                    # if you run construct_params().  render_postcompile mode
+                    # is not actually used by the execution internals, it's for
+                    # user-facing compilation code.  So this is likely a
+                    # current limitation of construct_params() which is not
+                    # doing the full blown postcompile; just assert that's
+                    # what it does for now.  it likely should be corrected
+                    # to make more sense.
+                    checkpositional=(["y", "z"], ["y", "z"]),
+                    dialect="sqlite",
+                    render_postcompile=True,
+                )
+            else:
+                self.assert_compile(
+                    stmt,
+                    "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
+                    "(:%s_1, :%s_2)" % (expected, expected),
+                    params={paramname: ["y", "z"]},
+                    # NOTE: this is what render_postcompile will do right now
+                    # if you run construct_params().  render_postcompile mode
+                    # is not actually used by the execution internals, it's for
+                    # user-facing compilation code.  So this is likely a
+                    # current limitation of construct_params() which is not
+                    # doing the full blown postcompile; just assert that's
+                    # what it does for now.  it likely should be corrected
+                    # to make more sense.
+                    checkparams={
+                        "%s_1" % expected: ["y", "z"],
+                        "%s_2" % expected: ["y", "z"],
+                    },
+                    dialect="default",
+                    render_postcompile=True,
+                )
+        else:
+            # this is what DefaultDialect actually does.
+            # this should be matched to DefaultDialect._init_compiled()
+            if use_positional:
+                compiled = stmt.compile(
+                    dialect=default.DefaultDialect(paramstyle="qmark")
+                )
+            else:
+                compiled = stmt.compile(dialect=default.DefaultDialect())
+            checkparams = compiled.construct_params(
+                {paramname: ["y", "z"]}, escape_names=False
+            )
+            # nothing actually happened.  if the compiler had
+            # render_postcompile set, the
+            # above weird param thing happens
+            eq_(checkparams, {paramname: ["y", "z"]})
+            expanded_state = compiled._process_parameters_for_postcompile(
+                checkparams
+            )
+            eq_(
+                expanded_state.additional_parameters,
+                {"%s_1" % (expected,): "y", "%s_2" % (expected,): "z"},
+            )
+            if use_positional:
+                eq_(
+                    expanded_state.positiontup,
+                    ["%s_1" % (expected,), "%s_2" % (expected,)],
+                )
+
 
 class UnsupportedTest(fixtures.TestBase):
     def test_unsupported_element_str_visit_name(self):
index fed371f62946d5613923945a99ae83eefe871cf9..40f92e41d01ee9dd248c0b4e7c42d7adbcda8918 100644 (file)
@@ -2095,7 +2095,8 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             ") SELECT cte.outer_cte FROM cte",
         )
 
-    def test_nesting_cte_in_recursive_cte(self):
+    @testing.fixture
+    def nesting_cte_in_recursive_cte(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True
         )
@@ -2104,20 +2105,85 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "rec_cte", recursive=True
         )
         rec_part = select(rec_cte.c.outer_cte).where(
-            rec_cte.c.outer_cte == literal(1)
+            rec_cte.c.outer_cte == literal(42)
         )
         rec_cte = rec_cte.union(rec_part)
 
         stmt = select(rec_cte)
+        return stmt
+
+    def test_nesting_cte_in_recursive_cte_positional(
+        self, nesting_cte_in_recursive_cte
+    ):
 
         self.assert_compile(
-            stmt,
+            nesting_cte_in_recursive_cte,
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+            "(SELECT ? AS inner_cte) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = ?) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkpositional=(1, 42),
+            dialect="default_qmark",
+        )
+
+    def test_nesting_cte_in_recursive_cte(self, nesting_cte_in_recursive_cte):
+        self.assert_compile(
+            nesting_cte_in_recursive_cte,
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+            "(SELECT :param_1 AS inner_cte) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = :param_2) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkparams={"param_1": 1, "param_2": 42},
+        )
+
+    @testing.fixture
+    def nesting_cte_in_recursive_cte_w_add_cte(self):
+        nesting_cte = select(literal(1).label("inner_cte")).cte(
+            "nesting", nesting=True
+        )
+
+        rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+            "rec_cte", recursive=True
+        )
+        rec_part = select(rec_cte.c.outer_cte).where(
+            rec_cte.c.outer_cte == literal(42)
+        )
+        rec_cte = rec_cte.union(rec_part)
+
+        stmt = select(rec_cte)
+        return stmt
+
+    def test_nesting_cte_in_recursive_cte_w_add_cte_positional(
+        self, nesting_cte_in_recursive_cte_w_add_cte
+    ):
+        self.assert_compile(
+            nesting_cte_in_recursive_cte_w_add_cte,
+            "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+            "(SELECT ? AS inner_cte) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = ?) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkpositional=(1, 42),
+            dialect="default_qmark",
+        )
+
+    def test_nesting_cte_in_recursive_cte_w_add_cte(
+        self, nesting_cte_in_recursive_cte_w_add_cte
+    ):
+        self.assert_compile(
+            nesting_cte_in_recursive_cte_w_add_cte,
             "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
             "(SELECT :param_1 AS inner_cte) "
             "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
             "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
             "WHERE rec_cte.outer_cte = :param_2) "
             "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkparams={"param_1": 1, "param_2": 42},
         )
 
     def test_recursive_nesting_cte_in_cte(self):
@@ -2219,18 +2285,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT cte.outer_cte FROM cte",
         )
 
-    def test_same_nested_cte_is_not_generated_twice(self):
+    @testing.fixture
+    def same_nested_cte_is_not_generated_twice(self):
         # Same = name and query
         nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
             "nesting_cte", nesting=True
         )
         select_add_cte = select(
-            (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+            (nesting_cte_used_twice.c.inner_cte_1 + 2).label("next_value")
         ).cte("nesting_2", nesting=True)
 
         union_cte = (
             select(
-                (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+                (nesting_cte_used_twice.c.inner_cte_1 - 3).label("next_value")
             )
             .union(select(select_add_cte))
             .cte("wrapper", nesting=True)
@@ -2241,9 +2308,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             .add_cte(nesting_cte_used_twice)
             .union(select(nesting_cte_used_twice))
         )
+        return stmt
 
+    def test_same_nested_cte_is_not_generated_twice_positional(
+        self, same_nested_cte_is_not_generated_twice
+    ):
         self.assert_compile(
-            stmt,
+            same_nested_cte_is_not_generated_twice,
+            "WITH nesting_cte AS "
+            "(SELECT ? AS inner_cte_1)"
+            ", wrapper AS "
+            "(WITH nesting_2 AS "
+            "(SELECT nesting_cte.inner_cte_1 + ? "
+            "AS next_value "
+            "FROM nesting_cte)"
+            " SELECT nesting_cte.inner_cte_1 - ? "
+            "AS next_value "
+            "FROM nesting_cte UNION SELECT nesting_2.next_value "
+            "AS next_value FROM nesting_2)"
+            " SELECT wrapper.next_value "
+            "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
+            "FROM nesting_cte",
+            checkpositional=(1, 2, 3),
+            dialect="default_qmark",
+        )
+
+    def test_same_nested_cte_is_not_generated_twice(
+        self, same_nested_cte_is_not_generated_twice
+    ):
+        self.assert_compile(
+            same_nested_cte_is_not_generated_twice,
             "WITH nesting_cte AS "
             "(SELECT :param_1 AS inner_cte_1)"
             ", wrapper AS "
@@ -2253,19 +2347,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM nesting_cte)"
             " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
             "AS next_value "
-            "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
-            "FROM nesting_2)"
+            "FROM nesting_cte UNION SELECT nesting_2.next_value "
+            "AS next_value FROM nesting_2)"
             " SELECT wrapper.next_value "
             "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
             "FROM nesting_cte",
+            checkparams={
+                "param_1": 1,
+                "inner_cte_1_2": 2,
+                "inner_cte_1_1": 3,
+            },
         )
 
-    def test_recursive_nesting_cte_in_recursive_cte(self):
+    @testing.fixture
+    def recursive_nesting_cte_in_recursive_cte(self):
         nesting_cte = select(literal(1).label("inner_cte")).cte(
             "nesting", nesting=True, recursive=True
         )
         nesting_rec_part = select(nesting_cte.c.inner_cte).where(
-            nesting_cte.c.inner_cte == literal(1)
+            nesting_cte.c.inner_cte == literal(2)
         )
         nesting_cte = nesting_cte.union(nesting_rec_part)
 
@@ -2273,14 +2373,37 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "rec_cte", recursive=True
         )
         rec_part = select(rec_cte.c.outer_cte).where(
-            rec_cte.c.outer_cte == literal(1)
+            rec_cte.c.outer_cte == literal(3)
         )
         rec_cte = rec_cte.union(rec_part)
 
         stmt = select(rec_cte)
+        return stmt
+
+    def test_recursive_nesting_cte_in_recursive_cte_positional(
+        self, recursive_nesting_cte_in_recursive_cte
+    ):
 
         self.assert_compile(
-            stmt,
+            recursive_nesting_cte_in_recursive_cte,
+            "WITH RECURSIVE rec_cte(outer_cte) AS ("
+            "WITH RECURSIVE nesting(inner_cte) AS "
+            "(SELECT ? AS inner_cte UNION "
+            "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+            "WHERE nesting.inner_cte = ?) "
+            "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+            "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+            "WHERE rec_cte.outer_cte = ?) "
+            "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkpositional=(1, 2, 3),
+            dialect="default_qmark",
+        )
+
+    def test_recursive_nesting_cte_in_recursive_cte(
+        self, recursive_nesting_cte_in_recursive_cte
+    ):
+        self.assert_compile(
+            recursive_nesting_cte_in_recursive_cte,
             "WITH RECURSIVE rec_cte(outer_cte) AS ("
             "WITH RECURSIVE nesting(inner_cte) AS "
             "(SELECT :param_1 AS inner_cte UNION "
@@ -2290,6 +2413,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
             "WHERE rec_cte.outer_cte = :param_3) "
             "SELECT rec_cte.outer_cte FROM rec_cte",
+            checkparams={"param_1": 1, "param_2": 2, "param_3": 3},
         )
 
     def test_select_from_insert_cte_with_nesting(self):
@@ -2418,7 +2542,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             ") SELECT cte.outer_cte FROM cte",
         )
 
-    def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
+    @testing.fixture
+    def cte_in_compound_select(self):
+        upper = select(literal(1).label("z"))
+
+        lower_a_cte = select(literal(2).label("x")).cte("xx", nesting=True)
+        lower_a = select(literal(3).label("y")).add_cte(lower_a_cte)
+        lower_b = select(literal(4).label("w"))
+
+        stmt = upper.union_all(lower_a.union_all(lower_b))
+        return stmt
+
+    def test_cte_in_compound_select_positional(self, cte_in_compound_select):
+        self.assert_compile(
+            cte_in_compound_select,
+            "SELECT ? AS z UNION ALL (WITH xx AS "
+            "(SELECT ? AS x) "
+            "SELECT ? AS y UNION ALL SELECT ? AS w)",
+            checkpositional=(1, 2, 3, 4),
+            dialect="default_qmark",
+        )
+
+    def test_cte_in_compound_select(self, cte_in_compound_select):
+        self.assert_compile(
+            cte_in_compound_select,
+            "SELECT :param_1 AS z UNION ALL (WITH xx AS "
+            "(SELECT :param_2 AS x) "
+            "SELECT :param_3 AS y UNION ALL SELECT :param_4 AS w)",
+            checkparams={
+                "param_1": 1,
+                "param_2": 2,
+                "param_3": 3,
+                "param_4": 4,
+            },
+        )
+
+    @testing.fixture
+    def recursive_cte_referenced_multiple_times_with_nesting_cte(self):
         rec_root = select(literal(1).label("the_value")).cte(
             "recursive_cte", recursive=True
         )
@@ -2431,7 +2591,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             exists(
                 select(rec_root_ref.c.the_value)
                 .where(rec_root_ref.c.the_value < 10)
-                .limit(1)
+                .limit(5)
             ).label("val")
         ).cte("should_continue", nesting=True)
 
@@ -2447,13 +2607,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
         rec_cte = rec_root.union_all(rec_part)
 
         stmt = rec_cte.select()
+        return stmt
 
+    def test_recursive_cte_referenced_multiple_times_with_nesting_cte_pos(
+        self, recursive_cte_referenced_multiple_times_with_nesting_cte
+    ):
         self.assert_compile(
-            stmt,
+            recursive_cte_referenced_multiple_times_with_nesting_cte,
+            "WITH RECURSIVE recursive_cte(the_value) AS ("
+            "SELECT ? AS the_value UNION ALL ("
+            "WITH allow_multiple_ref AS ("
+            "SELECT recursive_cte.the_value AS the_value "
+            "FROM recursive_cte)"
+            ", should_continue AS (SELECT EXISTS ("
+            "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
+            " WHERE allow_multiple_ref.the_value < ?"
+            " LIMIT ?) AS val) "
+            "SELECT allow_multiple_ref.the_value * ? AS anon_1"
+            " FROM allow_multiple_ref, should_continue "
+            "WHERE should_continue.val != 1"
+            " UNION ALL SELECT allow_multiple_ref.the_value * ?"
+            " AS anon_2 FROM allow_multiple_ref, should_continue"
+            " WHERE should_continue.val != 1))"
+            " SELECT recursive_cte.the_value FROM recursive_cte",
+            checkpositional=(1, 10, 5, 2, 3),
+            dialect="default_qmark",
+        )
+
+    def test_recursive_cte_referenced_multiple_times_with_nesting_cte(
+        self, recursive_cte_referenced_multiple_times_with_nesting_cte
+    ):
+        self.assert_compile(
+            recursive_cte_referenced_multiple_times_with_nesting_cte,
             "WITH RECURSIVE recursive_cte(the_value) AS ("
             "SELECT :param_1 AS the_value UNION ALL ("
             "WITH allow_multiple_ref AS ("
-            "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)"
+            "SELECT recursive_cte.the_value AS the_value "
+            "FROM recursive_cte)"
             ", should_continue AS (SELECT EXISTS ("
             "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
             " WHERE allow_multiple_ref.the_value < :the_value_2"
@@ -2465,4 +2655,11 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
             " AS anon_2 FROM allow_multiple_ref, should_continue"
             " WHERE should_continue.val != true))"
             " SELECT recursive_cte.the_value FROM recursive_cte",
+            checkparams={
+                "param_1": 1,
+                "param_2": 5,
+                "the_value_2": 10,
+                "the_value_1": 2,
+                "the_value_3": 3,
+            },
         )
index 1695771486a16f9f3010396e0427dfb705487554..37363273b20213cdddd36a3802ff50d47c6d364d 100644 (file)
@@ -193,7 +193,8 @@ class TraversalTest(
         ("name with~~tildes~~",),
         argnames="name",
     )
-    def test_bindparam_key_proc_for_copies(self, meth, name):
+    @testing.combinations(True, False, argnames="positional")
+    def test_bindparam_key_proc_for_copies(self, meth, name, positional):
         r"""test :ticket:`6249`.
 
         Revised for :ticket:`8056`.
@@ -225,13 +226,25 @@ class TraversalTest(
 
         token = re.sub(r"[%\(\) \$\[\]]", "_", name)
 
-        self.assert_compile(
-            expr,
-            '"%(name)s" IN (:%(token)s_1_1, '
-            ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token},
-            render_postcompile=True,
-            dialect="default",
-        )
+        if positional:
+            self.assert_compile(
+                expr,
+                '"%(name)s" IN (?, ?, ?)' % {"name": name},
+                checkpositional=(1, 2, 3),
+                render_postcompile=True,
+                dialect="default_qmark",
+            )
+        else:
+            tokens = ["%s_1_%s" % (token, i) for i in range(1, 4)]
+            self.assert_compile(
+                expr,
+                '"%(name)s" IN (:%(token)s_1_1, '
+                ":%(token)s_1_2, :%(token)s_1_3)"
+                % {"name": name, "token": token},
+                checkparams=dict(zip(tokens, [1, 2, 3])),
+                render_postcompile=True,
+                dialect="default",
+            )
 
     def test_expanding_in_bindparam_safe_to_clone(self):
         expr = column("x").in_([1, 2, 3])
index 3a9a06728cbd078f1bf65a3d1e7954e7f0cd0507..908fd9faaf058a24fd08ce06f6d1cd5a2c8a3f97 100644 (file)
@@ -777,6 +777,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 "
             "FOLLOWING AND :param_2 FOLLOWING) "
             "AS anon_1 FROM mytable",
+            checkparams={"name_1": "foo", "param_1": 1, "param_2": 5},
+        )
+
+    def test_funcfilter_windowing_range_positional(self):
+        self.assert_compile(
+            select(
+                func.rank()
+                .filter(table1.c.name > "foo")
+                .over(range_=(1, 5), partition_by=["description"])
+            ),
+            "SELECT rank() FILTER (WHERE mytable.name > ?) "
+            "OVER (PARTITION BY mytable.description RANGE BETWEEN ? "
+            "FOLLOWING AND ? FOLLOWING) "
+            "AS anon_1 FROM mytable",
+            checkpositional=("foo", 1, 5),
+            dialect="default_qmark",
         )
 
     def test_funcfilter_windowing_rows(self):