From: Federico Caselli Date: Sat, 19 Nov 2022 19:39:10 +0000 (+0100) Subject: Fix positional compiling bugs X-Git-Tag: rel_1_4_45~14^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=55ee628e9ef8e0e90786bbb550b124cf4b634f8a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix positional compiling bugs Fixed a series of issues regarding positionally rendered bound parameters, such as those used for SQLite, asyncpg, MySQL and others. Some compiled forms would not maintain the order of parameters correctly, such as the PostgreSQL ``regexp_replace()`` function as well as within the "nesting" feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. Fixes: #8827 Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb (cherry picked from commit 0f2baae6bf72353f785bad394684f2d6fa53e0ef) --- diff --git a/doc/build/changelog/unreleased_14/8827.rst b/doc/build/changelog/unreleased_14/8827.rst new file mode 100644 index 0000000000..677277e45d --- /dev/null +++ b/doc/build/changelog/unreleased_14/8827.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 8827 + + Fixed a series of issues regarding positionally rendered bound parameters, + such as those used for SQLite, asyncpg, MySQL and others. Some compiled + forms would not maintain the order of parameters correctly, such as the + PostgreSQL ``regexp_replace()`` function as well as within the "nesting" + feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 77f0dbd2df..417ab84b7b 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -941,7 +941,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_function(self, func, **kw): text = super(OracleCompiler, self).visit_function(func, **kw) if kw.get("asfrom", False): - text = "TABLE (%s)" % func + text = "TABLE (%s)" % text return text def visit_table_valued_column(self, element, **kw): @@ -1270,20 +1270,18 @@ class OracleCompiler(compiler.SQLCompiler): self.process(binary.right), ) - def _get_regexp_args(self, binary, kw): + def visit_regexp_match_op_binary(self, binary, operator, **kw): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) - return string, pattern, flags - - def visit_regexp_match_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) if flags is None: return "REGEXP_LIKE(%s, %s)" % (string, pattern) else: - return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags) + return "REGEXP_LIKE(%s, %s, %s)" % ( + string, + pattern, + self.process(flags, **kw), + ) def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_regexp_match_op_binary( @@ -1291,8 +1289,10 @@ class OracleCompiler(compiler.SQLCompiler): ) def visit_regexp_replace_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) + flags = binary.modifiers["flags"] if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( string, @@ -1304,7 +1304,7 @@ class OracleCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index c390553353..9ad8379e26 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2386,14 +2386,11 @@ class PGCompiler(compiler.SQLCompiler): return self._generate_generic_binary( binary, " %s* " % base_op, **kw ) - flags = self.process(flags, **kw) - string = self.process(binary.left, **kw) - pattern = self.process(binary.right, **kw) return "%s %s CONCAT('(?', %s, ')', %s)" % ( - string, + self.process(binary.left, **kw), base_op, - flags, - pattern, + self.process(flags, **kw), + self.process(binary.right, **kw), ) def visit_regexp_match_op_binary(self, binary, operator, **kw): @@ -2406,8 +2403,6 @@ class PGCompiler(compiler.SQLCompiler): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( @@ -2420,7 +2415,7 @@ class PGCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) def visit_empty_set_expr(self, element_types): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 611cd18218..8fbf3092aa 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -166,8 +166,8 @@ BIND_TEMPLATES = { "named": ":%(name)s", } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) OPERATORS = { # binary @@ -713,6 +713,7 @@ class SQLCompiler(Compiled): debugging use cases. """ + positiontup_level = None inline = False @@ -784,6 +785,7 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: + self.positiontup_level = {} self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] @@ -894,6 +896,8 @@ class SQLCompiler(Compiled): self.ctes_recursive = False if self.positional: self.cte_positional = {} + self.cte_level = {} + self.cte_order = collections.defaultdict(list) @contextlib.contextmanager def _nested_result(self): @@ -1696,7 +1700,13 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -1806,6 +1816,7 @@ class SQLCompiler(Compiled): ) def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs @@ -1818,7 +1829,7 @@ class SQLCompiler(Compiled): range_ = None return "%s OVER (%s)" % ( - over.element._compiler_dispatch(self, **kwargs), + text, " ".join( [ "%s BY %s" @@ -1964,7 +1975,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -2667,7 +2680,8 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) - elif not escaped_from: + self.positiontup_level[name] = len(self.stack) + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): # not quite the translate use case as we want to @@ -2786,6 +2800,8 @@ class SQLCompiler(Compiled): ] } ) + if self.positional: + self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -3495,13 +3511,16 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - # In compound query, CTEs are shared at the compound level - if not is_embedded_select: - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause(nesting_level=nesting_level) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kwargs.get("visiting_cte"), ) + + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -3677,6 +3696,7 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, + visiting_cte=None, ): """ include_following_stack @@ -3706,14 +3726,48 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: - self.positiontup = ( - sum([self.cte_positional[cte] for cte in ctes], []) - + self.positiontup - ) + self.cte_order[visiting_cte].extend(ctes) + + if visiting_cte is None and self.cte_order: + assert self.positiontup is not None + + def get_nested_positional(cte): + if cte in self.cte_order: + children = self.cte_order.pop(cte) + to_add = list( + itertools.chain.from_iterable( + get_nested_positional(child_cte) + for child_cte in children + ) + ) + if cte in self.cte_positional: + return reorder_positional( + self.cte_positional[cte], + to_add, + self.cte_level[children[0]], + ) + else: + return to_add + else: + return self.cte_positional.get(cte, []) + + def reorder_positional(pos, to_add, level): + if not level: + return to_add + pos + index = 0 + for index, name in enumerate(reversed(pos)): + if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 + break + return pos[:-index] + to_add + pos[-index:] + + to_add = get_nested_positional(None) + self.positiontup = reorder_positional( + self.positiontup, to_add, nesting_level + ) + cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) cte_text += "\n " @@ -3985,6 +4039,7 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4022,7 +4077,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -4162,7 +4219,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -4268,7 +4331,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index ba6ee14c3b..9a022265eb 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -7,7 +7,9 @@ from __future__ import absolute_import +from collections import defaultdict import contextlib +from copy import copy import re import sys import warnings @@ -499,6 +501,7 @@ class AssertsCompiledSQL(object): render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -512,8 +515,11 @@ class AssertsCompiledSQL(object): if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect(paramstyle="qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -645,7 +651,7 @@ class AssertsCompiledSQL(object): if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -665,6 +671,58 @@ class AssertsCompiledSQL(object): }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", pyformat_c.string) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables(object): diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index 8a8f51df01..2c58699081 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1554,11 +1554,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.table.c.myid.regexp_replace( "pattern", "replacement", flags="ig" ), - "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_3, :myid_2)", + "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_2, :myid_3)", checkparams={ "myid_1": "pattern", - "myid_3": "replacement", - "myid_2": "ig", + "myid_2": "replacement", + "myid_3": "ig", }, ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 0249c7952c..e9de407c8e 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3277,11 +3277,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.table.c.myid.regexp_replace( "pattern", "replacement", flags="ig" ), - "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_3)s, %(myid_2)s)", + "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_2)s, %(myid_3)s)", checkparams={ "myid_1": "pattern", - "myid_3": "replacement", - "myid_2": "ig", + "myid_2": "replacement", + "myid_3": "ig", }, ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 831ef18872..9ede4af923 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4794,6 +4794,118 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): stmt, expected, literal_binds=True, params=params ) + standalone_escape = testing.combinations( + ("normalname", "normalname"), + ("_name", "_name"), + ("[BracketsAndCase]", "_BracketsAndCase_"), + ("has spaces", "has_spaces"), + argnames="paramname, expected", + ) + + @standalone_escape + @testing.variation("use_positional", [True, False]) + def test_standalone_bindparam_escape( + self, paramname, expected, use_positional + ): + stmt = select(table1.c.myid).where( + table1.c.name == bindparam(paramname, value="x") + ) + if use_positional: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = ?", + params={paramname: "y"}, + checkpositional=("y",), + dialect="sqlite", + ) + else: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = :%s" + % (expected,), + params={paramname: "y"}, + checkparams={expected: "y"}, + dialect="default", + ) + + @standalone_escape + @testing.variation("use_assert_compile", [True, False]) + @testing.variation("use_positional", [True, False]) + def test_standalone_bindparam_escape_expanding( + self, paramname, expected, use_assert_compile, use_positional + ): + stmt = select(table1.c.myid).where( + table1.c.name.in_(bindparam(paramname, value=["a", "b"])) + ) + if use_assert_compile: + if use_positional: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable " + "WHERE mytable.name IN (?, ?)", + params={paramname: ["y", "z"]}, + # NOTE: this is what render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + checkpositional=(["y", "z"], ["y", "z"]), + dialect="sqlite", + render_postcompile=True, + ) + else: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name IN " + "(:%s_1, :%s_2)" % (expected, expected), + params={paramname: ["y", "z"]}, + # NOTE: this is what render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + checkparams={ + "%s_1" % expected: ["y", "z"], + "%s_2" % expected: ["y", "z"], + }, + dialect="default", + render_postcompile=True, + ) + else: + # this is what DefaultDialect actually does. + # this should be matched to DefaultDialect._init_compiled() + if use_positional: + compiled = stmt.compile( + dialect=default.DefaultDialect(paramstyle="qmark") + ) + else: + compiled = stmt.compile(dialect=default.DefaultDialect()) + checkparams = compiled.construct_params( + {paramname: ["y", "z"]}, escape_names=False + ) + # nothing actually happened. if the compiler had + # render_postcompile set, the + # above weird param thing happens + eq_(checkparams, {paramname: ["y", "z"]}) + expanded_state = compiled._process_parameters_for_postcompile( + checkparams + ) + eq_( + expanded_state.additional_parameters, + {"%s_1" % (expected,): "y", "%s_2" % (expected,): "z"}, + ) + if use_positional: + eq_( + expanded_state.positiontup, + ["%s_1" % (expected,), "%s_2" % (expected,)], + ) + class UnsupportedTest(fixtures.TestBase): def test_unsupported_element_str_visit_name(self): diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index fed371f629..40f92e41d0 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -2095,7 +2095,8 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) - def test_nesting_cte_in_recursive_cte(self): + @testing.fixture + def nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True ) @@ -2104,20 +2105,85 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "rec_cte", recursive=True ) rec_part = select(rec_cte.c.outer_cte).where( - rec_cte.c.outer_cte == literal(1) + rec_cte.c.outer_cte == literal(42) ) rec_cte = rec_cte.union(rec_part) stmt = select(rec_cte) + return stmt + + def test_nesting_cte_in_recursive_cte_positional( + self, nesting_cte_in_recursive_cte + ): self.assert_compile( - stmt, + nesting_cte_in_recursive_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT ? AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 42), + dialect="default_qmark", + ) + + def test_nesting_cte_in_recursive_cte(self, nesting_cte_in_recursive_cte): + self.assert_compile( + nesting_cte_in_recursive_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT :param_1 AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = :param_2) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 42}, + ) + + @testing.fixture + def nesting_cte_in_recursive_cte_w_add_cte(self): + nesting_cte = select(literal(1).label("inner_cte")).cte( + "nesting", nesting=True + ) + + rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte( + "rec_cte", recursive=True + ) + rec_part = select(rec_cte.c.outer_cte).where( + rec_cte.c.outer_cte == literal(42) + ) + rec_cte = rec_cte.union(rec_part) + + stmt = select(rec_cte) + return stmt + + def test_nesting_cte_in_recursive_cte_w_add_cte_positional( + self, nesting_cte_in_recursive_cte_w_add_cte + ): + self.assert_compile( + nesting_cte_in_recursive_cte_w_add_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT ? AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 42), + dialect="default_qmark", + ) + + def test_nesting_cte_in_recursive_cte_w_add_cte( + self, nesting_cte_in_recursive_cte_w_add_cte + ): + self.assert_compile( + nesting_cte_in_recursive_cte_w_add_cte, "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " "(SELECT :param_1 AS inner_cte) " "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " "WHERE rec_cte.outer_cte = :param_2) " "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 42}, ) def test_recursive_nesting_cte_in_cte(self): @@ -2219,18 +2285,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) - def test_same_nested_cte_is_not_generated_twice(self): + @testing.fixture + def same_nested_cte_is_not_generated_twice(self): # Same = name and query nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( "nesting_cte", nesting=True ) select_add_cte = select( - (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 + 2).label("next_value") ).cte("nesting_2", nesting=True) union_cte = ( select( - (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 - 3).label("next_value") ) .union(select(select_add_cte)) .cte("wrapper", nesting=True) @@ -2241,9 +2308,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): .add_cte(nesting_cte_used_twice) .union(select(nesting_cte_used_twice)) ) + return stmt + def test_same_nested_cte_is_not_generated_twice_positional( + self, same_nested_cte_is_not_generated_twice + ): self.assert_compile( - stmt, + same_nested_cte_is_not_generated_twice, + "WITH nesting_cte AS " + "(SELECT ? AS inner_cte_1)" + ", wrapper AS " + "(WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + ? " + "AS next_value " + "FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - ? " + "AS next_value " + "FROM nesting_cte UNION SELECT nesting_2.next_value " + "AS next_value FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + checkpositional=(1, 2, 3), + dialect="default_qmark", + ) + + def test_same_nested_cte_is_not_generated_twice( + self, same_nested_cte_is_not_generated_twice + ): + self.assert_compile( + same_nested_cte_is_not_generated_twice, "WITH nesting_cte AS " "(SELECT :param_1 AS inner_cte_1)" ", wrapper AS " @@ -2253,19 +2347,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM nesting_cte)" " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 " "AS next_value " - "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " - "FROM nesting_2)" + "FROM nesting_cte UNION SELECT nesting_2.next_value " + "AS next_value FROM nesting_2)" " SELECT wrapper.next_value " "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " "FROM nesting_cte", + checkparams={ + "param_1": 1, + "inner_cte_1_2": 2, + "inner_cte_1_1": 3, + }, ) - def test_recursive_nesting_cte_in_recursive_cte(self): + @testing.fixture + def recursive_nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True, recursive=True ) nesting_rec_part = select(nesting_cte.c.inner_cte).where( - nesting_cte.c.inner_cte == literal(1) + nesting_cte.c.inner_cte == literal(2) ) nesting_cte = nesting_cte.union(nesting_rec_part) @@ -2273,14 +2373,37 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "rec_cte", recursive=True ) rec_part = select(rec_cte.c.outer_cte).where( - rec_cte.c.outer_cte == literal(1) + rec_cte.c.outer_cte == literal(3) ) rec_cte = rec_cte.union(rec_part) stmt = select(rec_cte) + return stmt + + def test_recursive_nesting_cte_in_recursive_cte_positional( + self, recursive_nesting_cte_in_recursive_cte + ): self.assert_compile( - stmt, + recursive_nesting_cte_in_recursive_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (" + "WITH RECURSIVE nesting(inner_cte) AS " + "(SELECT ? AS inner_cte UNION " + "SELECT nesting.inner_cte AS inner_cte FROM nesting " + "WHERE nesting.inner_cte = ?) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 2, 3), + dialect="default_qmark", + ) + + def test_recursive_nesting_cte_in_recursive_cte( + self, recursive_nesting_cte_in_recursive_cte + ): + self.assert_compile( + recursive_nesting_cte_in_recursive_cte, "WITH RECURSIVE rec_cte(outer_cte) AS (" "WITH RECURSIVE nesting(inner_cte) AS " "(SELECT :param_1 AS inner_cte UNION " @@ -2290,6 +2413,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " "WHERE rec_cte.outer_cte = :param_3) " "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 2, "param_3": 3}, ) def test_select_from_insert_cte_with_nesting(self): @@ -2418,7 +2542,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) - def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self): + @testing.fixture + def cte_in_compound_select(self): + upper = select(literal(1).label("z")) + + lower_a_cte = select(literal(2).label("x")).cte("xx", nesting=True) + lower_a = select(literal(3).label("y")).add_cte(lower_a_cte) + lower_b = select(literal(4).label("w")) + + stmt = upper.union_all(lower_a.union_all(lower_b)) + return stmt + + def test_cte_in_compound_select_positional(self, cte_in_compound_select): + self.assert_compile( + cte_in_compound_select, + "SELECT ? AS z UNION ALL (WITH xx AS " + "(SELECT ? AS x) " + "SELECT ? AS y UNION ALL SELECT ? AS w)", + checkpositional=(1, 2, 3, 4), + dialect="default_qmark", + ) + + def test_cte_in_compound_select(self, cte_in_compound_select): + self.assert_compile( + cte_in_compound_select, + "SELECT :param_1 AS z UNION ALL (WITH xx AS " + "(SELECT :param_2 AS x) " + "SELECT :param_3 AS y UNION ALL SELECT :param_4 AS w)", + checkparams={ + "param_1": 1, + "param_2": 2, + "param_3": 3, + "param_4": 4, + }, + ) + + @testing.fixture + def recursive_cte_referenced_multiple_times_with_nesting_cte(self): rec_root = select(literal(1).label("the_value")).cte( "recursive_cte", recursive=True ) @@ -2431,7 +2591,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): exists( select(rec_root_ref.c.the_value) .where(rec_root_ref.c.the_value < 10) - .limit(1) + .limit(5) ).label("val") ).cte("should_continue", nesting=True) @@ -2447,13 +2607,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): rec_cte = rec_root.union_all(rec_part) stmt = rec_cte.select() + return stmt + def test_recursive_cte_referenced_multiple_times_with_nesting_cte_pos( + self, recursive_cte_referenced_multiple_times_with_nesting_cte + ): self.assert_compile( - stmt, + recursive_cte_referenced_multiple_times_with_nesting_cte, + "WITH RECURSIVE recursive_cte(the_value) AS (" + "SELECT ? AS the_value UNION ALL (" + "WITH allow_multiple_ref AS (" + "SELECT recursive_cte.the_value AS the_value " + "FROM recursive_cte)" + ", should_continue AS (SELECT EXISTS (" + "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref" + " WHERE allow_multiple_ref.the_value < ?" + " LIMIT ?) AS val) " + "SELECT allow_multiple_ref.the_value * ? AS anon_1" + " FROM allow_multiple_ref, should_continue " + "WHERE should_continue.val != 1" + " UNION ALL SELECT allow_multiple_ref.the_value * ?" + " AS anon_2 FROM allow_multiple_ref, should_continue" + " WHERE should_continue.val != 1))" + " SELECT recursive_cte.the_value FROM recursive_cte", + checkpositional=(1, 10, 5, 2, 3), + dialect="default_qmark", + ) + + def test_recursive_cte_referenced_multiple_times_with_nesting_cte( + self, recursive_cte_referenced_multiple_times_with_nesting_cte + ): + self.assert_compile( + recursive_cte_referenced_multiple_times_with_nesting_cte, "WITH RECURSIVE recursive_cte(the_value) AS (" "SELECT :param_1 AS the_value UNION ALL (" "WITH allow_multiple_ref AS (" - "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)" + "SELECT recursive_cte.the_value AS the_value " + "FROM recursive_cte)" ", should_continue AS (SELECT EXISTS (" "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref" " WHERE allow_multiple_ref.the_value < :the_value_2" @@ -2465,4 +2655,11 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " AS anon_2 FROM allow_multiple_ref, should_continue" " WHERE should_continue.val != true))" " SELECT recursive_cte.the_value FROM recursive_cte", + checkparams={ + "param_1": 1, + "param_2": 5, + "the_value_2": 10, + "the_value_1": 2, + "the_value_3": 3, + }, ) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 1695771486..37363273b2 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -193,7 +193,8 @@ class TraversalTest( ("name with~~tildes~~",), argnames="name", ) - def test_bindparam_key_proc_for_copies(self, meth, name): + @testing.combinations(True, False, argnames="positional") + def test_bindparam_key_proc_for_copies(self, meth, name, positional): r"""test :ticket:`6249`. Revised for :ticket:`8056`. @@ -225,13 +226,25 @@ class TraversalTest( token = re.sub(r"[%\(\) \$\[\]]", "_", name) - self.assert_compile( - expr, - '"%(name)s" IN (:%(token)s_1_1, ' - ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token}, - render_postcompile=True, - dialect="default", - ) + if positional: + self.assert_compile( + expr, + '"%(name)s" IN (?, ?, ?)' % {"name": name}, + checkpositional=(1, 2, 3), + render_postcompile=True, + dialect="default_qmark", + ) + else: + tokens = ["%s_1_%s" % (token, i) for i in range(1, 4)] + self.assert_compile( + expr, + '"%(name)s" IN (:%(token)s_1_1, ' + ":%(token)s_1_2, :%(token)s_1_3)" + % {"name": name, "token": token}, + checkparams=dict(zip(tokens, [1, 2, 3])), + render_postcompile=True, + dialect="default", + ) def test_expanding_in_bindparam_safe_to_clone(self): expr = column("x").in_([1, 2, 3]) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 3a9a06728c..908fd9faaf 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -777,6 +777,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 " "FOLLOWING AND :param_2 FOLLOWING) " "AS anon_1 FROM mytable", + checkparams={"name_1": "foo", "param_1": 1, "param_2": 5}, + ) + + def test_funcfilter_windowing_range_positional(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(range_=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > ?) " + "OVER (PARTITION BY mytable.description RANGE BETWEEN ? " + "FOLLOWING AND ? FOLLOWING) " + "AS anon_1 FROM mytable", + checkpositional=("foo", 1, 5), + dialect="default_qmark", ) def test_funcfilter_windowing_rows(self):