From: Federico Caselli Date: Sat, 19 Nov 2022 19:39:10 +0000 (+0100) Subject: Fix positional compiling bugs X-Git-Tag: rel_2_0_0b4~16^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=0f2baae6bf72353f785bad394684f2d6fa53e0ef;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix positional compiling bugs Fixed a series of issues regarding positionally rendered bound parameters, such as those used for SQLite, asyncpg, MySQL and others. Some compiled forms would not maintain the order of parameters correctly, such as the PostgreSQL ``regexp_replace()`` function as well as within the "nesting" feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. Fixes: #8827 Change-Id: I9813ed7c358cc5c1e26725c48df546b209a442cb --- diff --git a/doc/build/changelog/unreleased_14/8827.rst b/doc/build/changelog/unreleased_14/8827.rst new file mode 100644 index 0000000000..677277e45d --- /dev/null +++ b/doc/build/changelog/unreleased_14/8827.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, sql + :tickets: 8827 + + Fixed a series of issues regarding positionally rendered bound parameters, + such as those used for SQLite, asyncpg, MySQL and others. Some compiled + forms would not maintain the order of parameters correctly, such as the + PostgreSQL ``regexp_replace()`` function as well as within the "nesting" + feature of the :class:`.CTE` construct first introduced in :ticket:`4123`. diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 0d51bf73d5..41b9ac43d4 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -854,7 +854,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_function(self, func, **kw): text = super().visit_function(func, **kw) if kw.get("asfrom", False): - text = "TABLE (%s)" % func + text = "TABLE (%s)" % text return text def visit_table_valued_column(self, element, **kw): @@ -1222,20 +1222,18 @@ class OracleCompiler(compiler.SQLCompiler): self.process(binary.right), ) - def _get_regexp_args(self, binary, kw): + def visit_regexp_match_op_binary(self, binary, operator, **kw): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) - return string, pattern, flags - - def visit_regexp_match_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) if flags is None: return "REGEXP_LIKE(%s, %s)" % (string, pattern) else: - return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags) + return "REGEXP_LIKE(%s, %s, %s)" % ( + string, + pattern, + self.process(flags, **kw), + ) def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return "NOT %s" % self.visit_regexp_match_op_binary( @@ -1243,8 +1241,10 @@ class OracleCompiler(compiler.SQLCompiler): ) def visit_regexp_replace_op_binary(self, binary, operator, **kw): - string, pattern, flags = self._get_regexp_args(binary, kw) + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) + flags = binary.modifiers["flags"] if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( string, @@ -1256,7 +1256,7 @@ class OracleCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 99c48fb2f9..f9108094f2 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1747,14 +1747,11 @@ class PGCompiler(compiler.SQLCompiler): return self._generate_generic_binary( binary, " %s* " % base_op, **kw ) - flags = self.process(flags, **kw) - string = self.process(binary.left, **kw) - pattern = self.process(binary.right, **kw) return "%s %s CONCAT('(?', %s, ')', %s)" % ( - string, + self.process(binary.left, **kw), base_op, - flags, - pattern, + self.process(flags, **kw), + self.process(binary.right, **kw), ) def visit_regexp_match_op_binary(self, binary, operator, **kw): @@ -1767,8 +1764,6 @@ class PGCompiler(compiler.SQLCompiler): string = self.process(binary.left, **kw) pattern = self.process(binary.right, **kw) flags = binary.modifiers["flags"] - if flags is not None: - flags = self.process(flags, **kw) replacement = self.process(binary.modifiers["replacement"], **kw) if flags is None: return "REGEXP_REPLACE(%s, %s, %s)" % ( @@ -1781,7 +1776,7 @@ class PGCompiler(compiler.SQLCompiler): string, pattern, replacement, - flags, + self.process(flags, **kw), ) def visit_empty_set_expr(self, element_types): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 50cf9b477c..7ac279ee2e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -236,8 +236,8 @@ BIND_TEMPLATES = { } -_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]") -_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__")) +_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]") +_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___")) OPERATORS = { # binary @@ -973,6 +973,7 @@ class SQLCompiler(Compiled): debugging use cases. """ + positiontup_level: Optional[Dict[str, int]] = None inline: bool = False @@ -988,6 +989,8 @@ class SQLCompiler(Compiled): ctes_recursive: bool cte_positional: Dict[CTE, List[str]] + cte_level: Dict[CTE, int] + cte_order: Dict[Optional[CTE], List[CTE]] def __init__( self, @@ -1052,6 +1055,7 @@ class SQLCompiler(Compiled): # true if the paramstyle is positional self.positional = dialect.positional if self.positional: + self.positiontup_level = {} self.positiontup = [] self._numeric_binds = dialect.paramstyle == "numeric" self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] @@ -1215,6 +1219,8 @@ class SQLCompiler(Compiled): self.ctes_recursive = False if self.positional: self.cte_positional = {} + self.cte_level = {} + self.cte_order = collections.defaultdict(list) return ctes @@ -2103,7 +2109,13 @@ class SQLCompiler(Compiled): text = self.process(taf.element, **kw) if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -2231,6 +2243,7 @@ class SQLCompiler(Compiled): ) def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs @@ -2243,7 +2256,7 @@ class SQLCompiler(Compiled): range_ = None return "%s OVER (%s)" % ( - over.element._compiler_dispatch(self, **kwargs), + text, " ".join( [ "%s BY %s" @@ -2396,7 +2409,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kwargs.get("visiting_cte"), ) + text ) @@ -3222,7 +3237,8 @@ class SQLCompiler(Compiled): positional_names.append(name) else: self.positiontup.append(name) # type: ignore[union-attr] - elif not escaped_from: + self.positiontup_level[name] = len(self.stack) # type: ignore[index] # noqa: E501 + if not escaped_from: if _BIND_TRANSLATE_RE.search(name): # not quite the translate use case as we want to @@ -3333,6 +3349,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = new_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level else: cte_level = len(self.stack) if nesting else 1 @@ -3396,6 +3414,8 @@ class SQLCompiler(Compiled): self.level_name_by_cte[_reference_cte] = cte_level_name + ( cte_opts, ) + if self.positional: + self.cte_level[cte] = cte_level if pre_alias_cte not in self.ctes: self.visit_cte(pre_alias_cte, **kwargs) @@ -4129,13 +4149,16 @@ class SQLCompiler(Compiled): if per_dialect: text += " " + self.get_statement_hint_text(per_dialect) - if self.ctes: - # In compound query, CTEs are shared at the compound level - if not is_embedded_select: - nesting_level = len(self.stack) if not toplevel else None - text = ( - self._render_cte_clause(nesting_level=nesting_level) + text + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kwargs.get("visiting_cte"), ) + + text + ) if select_stmt._suffixes: text += " " + self._generate_prefixes( @@ -4309,6 +4332,7 @@ class SQLCompiler(Compiled): self, nesting_level=None, include_following_stack=False, + visiting_cte=None, ): """ include_following_stack @@ -4341,19 +4365,47 @@ class SQLCompiler(Compiled): if not ctes: return "" - ctes_recursive = any([cte.recursive for cte in ctes]) if self.positional: - assert self.positiontup is not None - self.positiontup = ( - list( - itertools.chain.from_iterable( - self.cte_positional[cte] for cte in ctes - ) + self.cte_order[visiting_cte].extend(ctes) + + if visiting_cte is None and self.cte_order: + assert self.positiontup is not None + + def get_nested_positional(cte): + if cte in self.cte_order: + children = self.cte_order.pop(cte) + to_add = list( + itertools.chain.from_iterable( + get_nested_positional(child_cte) + for child_cte in children + ) + ) + if cte in self.cte_positional: + return reorder_positional( + self.cte_positional[cte], + to_add, + self.cte_level[children[0]], + ) + else: + return to_add + else: + return self.cte_positional.get(cte, []) + + def reorder_positional(pos, to_add, level): + if not level: + return to_add + pos + index = 0 + for index, name in enumerate(reversed(pos)): + if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501 + break + return pos[:-index] + to_add + pos[-index:] + + to_add = get_nested_positional(None) + self.positiontup = reorder_positional( + self.positiontup, to_add, nesting_level ) - + self.positiontup - ) cte_text = self.get_cte_preamble(ctes_recursive) + " " cte_text += ", \n".join([txt for txt in ctes.values()]) @@ -4930,6 +4982,7 @@ class SQLCompiler(Compiled): self._render_cte_clause( nesting_level=nesting_level, include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ), select_text, ) @@ -4997,7 +5050,9 @@ class SQLCompiler(Compiled): nesting_level = len(self.stack) if not toplevel else None text = ( self._render_cte_clause( - nesting_level=nesting_level, include_following_stack=True + nesting_level=nesting_level, + include_following_stack=True, + visiting_cte=kw.get("visiting_cte"), ) + text ) @@ -5146,7 +5201,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) @@ -5260,7 +5321,13 @@ class SQLCompiler(Compiled): if self.ctes: nesting_level = len(self.stack) if not toplevel else None - text = self._render_cte_clause(nesting_level=nesting_level) + text + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + visiting_cte=kw.get("visiting_cte"), + ) + + text + ) self.stack.pop(-1) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 97336d4166..2dcc611fa5 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2052,9 +2052,7 @@ class CTE( else: self.element._generate_fromclause_column_proxies(self) - def alias( - self, name: Optional[str] = None, flat: bool = False - ) -> NamedFromClause: + def alias(self, name: Optional[str] = None, flat: bool = False) -> CTE: """Return an :class:`_expression.Alias` of this :class:`_expression.CTE`. @@ -2078,7 +2076,7 @@ class CTE( _suffixes=self._suffixes, ) - def union(self, *other): + def union(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION`` of the original CTE against the given selectables provided as positional arguments. @@ -2107,7 +2105,7 @@ class CTE( _suffixes=self._suffixes, ) - def union_all(self, *other): + def union_all(self, *other: _SelectStatementForCompoundArgument) -> CTE: r"""Return a new :class:`_expression.CTE` with a SQL ``UNION ALL`` of the original CTE against the given selectables provided as positional arguments. diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 321c05b446..790a72ec84 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -9,7 +9,9 @@ from __future__ import annotations +from collections import defaultdict import contextlib +from copy import copy from itertools import filterfalse import re import sys @@ -493,6 +495,7 @@ class AssertsCompiledSQL: render_schema_translate=False, default_schema_name=None, from_linting=False, + check_param_order=True, ): if use_default_dialect: dialect = default.DefaultDialect() @@ -506,8 +509,11 @@ class AssertsCompiledSQL: if dialect is None: dialect = config.db.dialect - elif dialect == "default": - dialect = default.DefaultDialect() + elif dialect == "default" or dialect == "default_qmark": + if dialect == "default": + dialect = default.DefaultDialect() + else: + dialect = default.DefaultDialect("qmark") dialect.supports_default_values = supports_default_values dialect.supports_default_metavalue = supports_default_metavalue elif dialect == "default_enhanced": @@ -632,7 +638,7 @@ class AssertsCompiledSQL: if checkparams is not None: eq_(c.construct_params(params), checkparams) if checkpositional is not None: - p = c.construct_params(params) + p = c.construct_params(params, escape_names=False) eq_(tuple([p[x] for x in c.positiontup]), checkpositional) if check_prefetch is not None: eq_(c.prefetch, check_prefetch) @@ -652,6 +658,58 @@ class AssertsCompiledSQL: }, check_post_param, ) + if check_param_order and getattr(c, "params", None): + + def get_dialect(paramstyle, positional): + cp = copy(dialect) + cp.paramstyle = paramstyle + cp.positional = positional + return cp + + pyformat_dialect = get_dialect("pyformat", False) + pyformat_c = clause.compile(dialect=pyformat_dialect, **kw) + stmt = re.sub(r"[\n\t]", "", str(pyformat_c)) + + qmark_dialect = get_dialect("qmark", True) + qmark_c = clause.compile(dialect=qmark_dialect, **kw) + values = list(qmark_c.positiontup) + escaped = qmark_c.escaped_bind_names + + for post_param in ( + qmark_c.post_compile_params | qmark_c.literal_execute_params + ): + name = qmark_c.bind_names[post_param] + if name in values: + values = [v for v in values if v != name] + positions = [] + pos_by_value = defaultdict(list) + for v in values: + try: + if v in pos_by_value: + start = pos_by_value[v][-1] + else: + start = 0 + esc = escaped.get(v, v) + pos = stmt.index("%%(%s)s" % (esc,), start) + 2 + positions.append(pos) + pos_by_value[v].append(pos) + except ValueError: + msg = "Expected to find bindparam %r in %r" % (v, stmt) + assert False, msg + + ordered = all( + positions[i - 1] < positions[i] + for i in range(1, len(positions)) + ) + + expected = [v for _, v in sorted(zip(positions, values))] + + msg = ( + "Order of parameters %s does not match the order " + "in the statement %s. Statement %r" % (values, expected, stmt) + ) + + is_true(ordered, msg) class ComparesTables: diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index dff8584e33..255efdb3c0 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -1770,11 +1770,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.table.c.myid.regexp_replace( "pattern", "replacement", flags="ig" ), - "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_3, :myid_2)", + "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_2, :myid_3)", checkparams={ "myid_1": "pattern", - "myid_3": "replacement", - "myid_2": "ig", + "myid_2": "replacement", + "myid_3": "ig", }, ) diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index ee3372c749..cf5f1c8267 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3415,11 +3415,11 @@ class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.table.c.myid.regexp_replace( "pattern", "replacement", flags="ig" ), - "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_3)s, %(myid_2)s)", + "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_2)s, %(myid_3)s)", checkparams={ "myid_1": "pattern", - "myid_3": "replacement", - "myid_2": "ig", + "myid_2": "replacement", + "myid_3": "ig", }, ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index c71cfd61f0..205ce51576 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -4880,6 +4880,124 @@ class BindParameterTest(AssertsCompiledSQL, fixtures.TestBase): stmt, expected, literal_binds=True, params=params ) + standalone_escape = testing.combinations( + ("normalname", "normalname"), + ("_name", "_name"), + ("[BracketsAndCase]", "_BracketsAndCase_"), + ("has spaces", "has_spaces"), + argnames="paramname, expected", + ) + + @standalone_escape + @testing.variation("use_positional", [True, False]) + def test_standalone_bindparam_escape( + self, paramname, expected, use_positional + ): + stmt = select(table1.c.myid).where( + table1.c.name == bindparam(paramname, value="x") + ) + + if use_positional: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = ?", + params={paramname: "y"}, + checkpositional=("y",), + dialect="sqlite", + ) + else: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name = :%s" + % (expected,), + params={paramname: "y"}, + checkparams={expected: "y"}, + dialect="default", + ) + + @standalone_escape + @testing.variation("use_assert_compile", [True, False]) + @testing.variation("use_positional", [True, False]) + def test_standalone_bindparam_escape_expanding( + self, paramname, expected, use_assert_compile, use_positional + ): + stmt = select(table1.c.myid).where( + table1.c.name.in_(bindparam(paramname, value=["a", "b"])) + ) + + if use_assert_compile: + if use_positional: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable " + "WHERE mytable.name IN (?, ?)", + params={paramname: ["y", "z"]}, + # NOTE: this is what render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + checkpositional=(["y", "z"], ["y", "z"]), + dialect="sqlite", + render_postcompile=True, + ) + else: + self.assert_compile( + stmt, + "SELECT mytable.myid FROM mytable WHERE mytable.name IN " + "(:%s_1, :%s_2)" % (expected, expected), + params={paramname: ["y", "z"]}, + # NOTE: this is what render_postcompile will do right now + # if you run construct_params(). render_postcompile mode + # is not actually used by the execution internals, it's for + # user-facing compilation code. So this is likely a + # current limitation of construct_params() which is not + # doing the full blown postcompile; just assert that's + # what it does for now. it likely should be corrected + # to make more sense. + checkparams={ + "%s_1" % expected: ["y", "z"], + "%s_2" % expected: ["y", "z"], + }, + dialect="default", + render_postcompile=True, + ) + else: + # this is what DefaultDialect actually does. + # this should be matched to DefaultDialect._init_compiled() + if use_positional: + compiled = stmt.compile( + dialect=default.DefaultDialect(paramstyle="qmark") + ) + else: + compiled = stmt.compile(dialect=default.DefaultDialect()) + + checkparams = compiled.construct_params( + {paramname: ["y", "z"]}, escape_names=False + ) + + # nothing actually happened. if the compiler had + # render_postcompile set, the + # above weird param thing happens + eq_(checkparams, {paramname: ["y", "z"]}) + + expanded_state = compiled._process_parameters_for_postcompile( + checkparams + ) + eq_( + expanded_state.additional_parameters, + {f"{expected}_1": "y", f"{expected}_2": "z"}, + ) + + if use_positional: + eq_( + expanded_state.positiontup, + [f"{expected}_1", f"{expected}_2"], + ) + class UnsupportedTest(fixtures.TestBase): def test_unsupported_element_str_visit_name(self): diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index f369518fc7..b89d18de62 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -2238,7 +2238,8 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) - def test_nesting_cte_in_recursive_cte(self): + @testing.fixture + def nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True ) @@ -2247,23 +2248,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "rec_cte", recursive=True ) rec_part = select(rec_cte.c.outer_cte).where( - rec_cte.c.outer_cte == literal(1) + rec_cte.c.outer_cte == literal(42) ) rec_cte = rec_cte.union(rec_part) stmt = select(rec_cte) + return stmt + + def test_nesting_cte_in_recursive_cte_positional( + self, nesting_cte_in_recursive_cte + ): self.assert_compile( - stmt, + nesting_cte_in_recursive_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT ? AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 42), + dialect="default_qmark", + ) + + def test_nesting_cte_in_recursive_cte(self, nesting_cte_in_recursive_cte): + self.assert_compile( + nesting_cte_in_recursive_cte, "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " "(SELECT :param_1 AS inner_cte) " "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " "WHERE rec_cte.outer_cte = :param_2) " "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 42}, ) - def test_nesting_cte_in_recursive_cte_w_add_cte(self): + @testing.fixture + def nesting_cte_in_recursive_cte_w_add_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True ) @@ -2272,20 +2293,40 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "rec_cte", recursive=True ) rec_part = select(rec_cte.c.outer_cte).where( - rec_cte.c.outer_cte == literal(1) + rec_cte.c.outer_cte == literal(42) ) rec_cte = rec_cte.union(rec_part) stmt = select(rec_cte) + return stmt + def test_nesting_cte_in_recursive_cte_w_add_cte_positional( + self, nesting_cte_in_recursive_cte_w_add_cte + ): self.assert_compile( - stmt, + nesting_cte_in_recursive_cte_w_add_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " + "(SELECT ? AS inner_cte) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 42), + dialect="default_qmark", + ) + + def test_nesting_cte_in_recursive_cte_w_add_cte( + self, nesting_cte_in_recursive_cte_w_add_cte + ): + self.assert_compile( + nesting_cte_in_recursive_cte_w_add_cte, "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS " "(SELECT :param_1 AS inner_cte) " "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " "WHERE rec_cte.outer_cte = :param_2) " "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 42}, ) def test_recursive_nesting_cte_in_cte(self): @@ -2387,18 +2428,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT cte.outer_cte FROM cte", ) - def test_same_nested_cte_is_not_generated_twice(self): + @testing.fixture + def same_nested_cte_is_not_generated_twice(self): # Same = name and query nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( "nesting_cte", nesting=True ) select_add_cte = select( - (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 + 2).label("next_value") ).cte("nesting_2", nesting=True) union_cte = ( select( - (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 - 3).label("next_value") ) .union(select(select_add_cte)) .cte("wrapper", nesting=True) @@ -2409,9 +2451,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): .add_cte(nesting_cte_used_twice) .union(select(nesting_cte_used_twice)) ) + return stmt + def test_same_nested_cte_is_not_generated_twice_positional( + self, same_nested_cte_is_not_generated_twice + ): self.assert_compile( - stmt, + same_nested_cte_is_not_generated_twice, + "WITH nesting_cte AS " + "(SELECT ? AS inner_cte_1)" + ", wrapper AS " + "(WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + ? " + "AS next_value " + "FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - ? " + "AS next_value " + "FROM nesting_cte UNION SELECT nesting_2.next_value " + "AS next_value FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + checkpositional=(1, 2, 3), + dialect="default_qmark", + ) + + def test_same_nested_cte_is_not_generated_twice( + self, same_nested_cte_is_not_generated_twice + ): + self.assert_compile( + same_nested_cte_is_not_generated_twice, "WITH nesting_cte AS " "(SELECT :param_1 AS inner_cte_1)" ", wrapper AS " @@ -2421,11 +2490,16 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "FROM nesting_cte)" " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 " "AS next_value " - "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " - "FROM nesting_2)" + "FROM nesting_cte UNION SELECT nesting_2.next_value " + "AS next_value FROM nesting_2)" " SELECT wrapper.next_value " "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " "FROM nesting_cte", + checkparams={ + "param_1": 1, + "inner_cte_1_2": 2, + "inner_cte_1_1": 3, + }, ) def test_add_cte_dont_nest_in_two_places(self): @@ -2458,18 +2532,19 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ): stmt.compile() - def test_same_nested_cte_is_not_generated_twice_w_add_cte(self): + @testing.fixture + def same_nested_cte_is_not_generated_twice_w_add_cte(self): # Same = name and query nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte( "nesting_cte" ) select_add_cte = select( - (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 + 2).label("next_value") ).cte("nesting_2") union_cte = ( select( - (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value") + (nesting_cte_used_twice.c.inner_cte_1 - 3).label("next_value") ) .add_cte(nesting_cte_used_twice) .union( @@ -2483,31 +2558,60 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): .add_cte(nesting_cte_used_twice, nest_here=True) .union(select(nesting_cte_used_twice)) ) + return stmt + def test_same_nested_cte_is_not_generated_twice_w_add_cte_positional( + self, same_nested_cte_is_not_generated_twice_w_add_cte + ): self.assert_compile( - stmt, - "WITH nesting_cte AS " - "(SELECT :param_1 AS inner_cte_1)" - ", wrapper AS " - "(WITH nesting_2 AS " - "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 " + same_nested_cte_is_not_generated_twice_w_add_cte, + "WITH nesting_cte AS (SELECT ? AS inner_cte_1)" + ", wrapper AS (WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + ? " + "AS next_value FROM nesting_cte)" + " SELECT nesting_cte.inner_cte_1 - ? " "AS next_value " - "FROM nesting_cte)" + "FROM nesting_cte UNION " + "SELECT nesting_2.next_value AS next_value " + "FROM nesting_2)" + " SELECT wrapper.next_value " + "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " + "FROM nesting_cte", + checkpositional=(1, 2, 3), + dialect="default_qmark", + ) + + def test_same_nested_cte_is_not_generated_twice_w_add_cte( + self, same_nested_cte_is_not_generated_twice_w_add_cte + ): + self.assert_compile( + same_nested_cte_is_not_generated_twice_w_add_cte, + "WITH nesting_cte AS (SELECT :param_1 AS inner_cte_1)" + ", wrapper AS (WITH nesting_2 AS " + "(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 " + "AS next_value FROM nesting_cte)" " SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 " "AS next_value " - "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value " + "FROM nesting_cte UNION " + "SELECT nesting_2.next_value AS next_value " "FROM nesting_2)" " SELECT wrapper.next_value " "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 " "FROM nesting_cte", + checkparams={ + "param_1": 1, + "inner_cte_1_2": 2, + "inner_cte_1_1": 3, + }, ) - def test_recursive_nesting_cte_in_recursive_cte(self): + @testing.fixture + def recursive_nesting_cte_in_recursive_cte(self): nesting_cte = select(literal(1).label("inner_cte")).cte( "nesting", nesting=True, recursive=True ) nesting_rec_part = select(nesting_cte.c.inner_cte).where( - nesting_cte.c.inner_cte == literal(1) + nesting_cte.c.inner_cte == literal(2) ) nesting_cte = nesting_cte.union(nesting_rec_part) @@ -2515,14 +2619,37 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "rec_cte", recursive=True ) rec_part = select(rec_cte.c.outer_cte).where( - rec_cte.c.outer_cte == literal(1) + rec_cte.c.outer_cte == literal(3) ) rec_cte = rec_cte.union(rec_part) stmt = select(rec_cte) + return stmt + + def test_recursive_nesting_cte_in_recursive_cte_positional( + self, recursive_nesting_cte_in_recursive_cte + ): self.assert_compile( - stmt, + recursive_nesting_cte_in_recursive_cte, + "WITH RECURSIVE rec_cte(outer_cte) AS (" + "WITH RECURSIVE nesting(inner_cte) AS " + "(SELECT ? AS inner_cte UNION " + "SELECT nesting.inner_cte AS inner_cte FROM nesting " + "WHERE nesting.inner_cte = ?) " + "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION " + "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " + "WHERE rec_cte.outer_cte = ?) " + "SELECT rec_cte.outer_cte FROM rec_cte", + checkpositional=(1, 2, 3), + dialect="default_qmark", + ) + + def test_recursive_nesting_cte_in_recursive_cte( + self, recursive_nesting_cte_in_recursive_cte + ): + self.assert_compile( + recursive_nesting_cte_in_recursive_cte, "WITH RECURSIVE rec_cte(outer_cte) AS (" "WITH RECURSIVE nesting(inner_cte) AS " "(SELECT :param_1 AS inner_cte UNION " @@ -2532,6 +2659,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte " "WHERE rec_cte.outer_cte = :param_3) " "SELECT rec_cte.outer_cte FROM rec_cte", + checkparams={"param_1": 1, "param_2": 2, "param_3": 3}, ) def test_select_from_insert_cte_with_nesting(self): @@ -2686,7 +2814,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): ") SELECT cte.outer_cte FROM cte", ) - def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self): + @testing.fixture + def cte_in_compound_select(self): + upper = select(literal(1).label("z")) + + lower_a_cte = select(literal(2).label("x")).cte("xx", nesting=True) + lower_a = select(literal(3).label("y")).add_cte(lower_a_cte) + lower_b = select(literal(4).label("w")) + + stmt = upper.union_all(lower_a.union_all(lower_b)) + return stmt + + def test_cte_in_compound_select_positional(self, cte_in_compound_select): + self.assert_compile( + cte_in_compound_select, + "SELECT ? AS z UNION ALL (WITH xx AS " + "(SELECT ? AS x) " + "SELECT ? AS y UNION ALL SELECT ? AS w)", + checkpositional=(1, 2, 3, 4), + dialect="default_qmark", + ) + + def test_cte_in_compound_select(self, cte_in_compound_select): + self.assert_compile( + cte_in_compound_select, + "SELECT :param_1 AS z UNION ALL (WITH xx AS " + "(SELECT :param_2 AS x) " + "SELECT :param_3 AS y UNION ALL SELECT :param_4 AS w)", + checkparams={ + "param_1": 1, + "param_2": 2, + "param_3": 3, + "param_4": 4, + }, + ) + + @testing.fixture + def recursive_cte_referenced_multiple_times_with_nesting_cte(self): rec_root = select(literal(1).label("the_value")).cte( "recursive_cte", recursive=True ) @@ -2699,7 +2863,7 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): exists( select(rec_root_ref.c.the_value) .where(rec_root_ref.c.the_value < 10) - .limit(1) + .limit(5) ).label("val") ).cte("should_continue", nesting=True) @@ -2715,13 +2879,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): rec_cte = rec_root.union_all(rec_part) stmt = rec_cte.select() + return stmt + def test_recursive_cte_referenced_multiple_times_with_nesting_cte_pos( + self, recursive_cte_referenced_multiple_times_with_nesting_cte + ): self.assert_compile( - stmt, + recursive_cte_referenced_multiple_times_with_nesting_cte, + "WITH RECURSIVE recursive_cte(the_value) AS (" + "SELECT ? AS the_value UNION ALL (" + "WITH allow_multiple_ref AS (" + "SELECT recursive_cte.the_value AS the_value " + "FROM recursive_cte)" + ", should_continue AS (SELECT EXISTS (" + "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref" + " WHERE allow_multiple_ref.the_value < ?" + " LIMIT ?) AS val) " + "SELECT allow_multiple_ref.the_value * ? AS anon_1" + " FROM allow_multiple_ref, should_continue " + "WHERE should_continue.val != 1" + " UNION ALL SELECT allow_multiple_ref.the_value * ?" + " AS anon_2 FROM allow_multiple_ref, should_continue" + " WHERE should_continue.val != 1))" + " SELECT recursive_cte.the_value FROM recursive_cte", + checkpositional=(1, 10, 5, 2, 3), + dialect="default_qmark", + ) + + def test_recursive_cte_referenced_multiple_times_with_nesting_cte( + self, recursive_cte_referenced_multiple_times_with_nesting_cte + ): + self.assert_compile( + recursive_cte_referenced_multiple_times_with_nesting_cte, "WITH RECURSIVE recursive_cte(the_value) AS (" "SELECT :param_1 AS the_value UNION ALL (" "WITH allow_multiple_ref AS (" - "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)" + "SELECT recursive_cte.the_value AS the_value " + "FROM recursive_cte)" ", should_continue AS (SELECT EXISTS (" "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref" " WHERE allow_multiple_ref.the_value < :the_value_2" @@ -2733,6 +2927,13 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL): " AS anon_2 FROM allow_multiple_ref, should_continue" " WHERE should_continue.val != true))" " SELECT recursive_cte.the_value FROM recursive_cte", + checkparams={ + "param_1": 1, + "param_2": 5, + "the_value_2": 10, + "the_value_1": 2, + "the_value_3": 3, + }, ) @testing.combinations(True, False) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 158707c6ae..8940276e33 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -194,7 +194,8 @@ class TraversalTest( ("name with~~tildes~~",), argnames="name", ) - def test_bindparam_key_proc_for_copies(self, meth, name): + @testing.combinations(True, False, argnames="positional") + def test_bindparam_key_proc_for_copies(self, meth, name, positional): r"""test :ticket:`6249`. Revised for :ticket:`8056`. @@ -226,13 +227,25 @@ class TraversalTest( token = re.sub(r"[%\(\) \$\[\]]", "_", name) - self.assert_compile( - expr, - '"%(name)s" IN (:%(token)s_1_1, ' - ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token}, - render_postcompile=True, - dialect="default", - ) + if positional: + self.assert_compile( + expr, + '"%(name)s" IN (?, ?, ?)' % {"name": name}, + checkpositional=(1, 2, 3), + render_postcompile=True, + dialect="default_qmark", + ) + else: + tokens = ["%s_1_%s" % (token, i) for i in range(1, 4)] + self.assert_compile( + expr, + '"%(name)s" IN (:%(token)s_1_1, ' + ":%(token)s_1_2, :%(token)s_1_3)" + % {"name": name, "token": token}, + checkparams=dict(zip(tokens, [1, 2, 3])), + render_postcompile=True, + dialect="default", + ) def test_expanding_in_bindparam_safe_to_clone(self): expr = column("x").in_([1, 2, 3]) diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index c97c136249..1dafe3e8a5 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -774,6 +774,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 " "FOLLOWING AND :param_2 FOLLOWING) " "AS anon_1 FROM mytable", + checkparams={"name_1": "foo", "param_1": 1, "param_2": 5}, + ) + + def test_funcfilter_windowing_range_positional(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(range_=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > ?) " + "OVER (PARTITION BY mytable.description RANGE BETWEEN ? " + "FOLLOWING AND ? FOLLOWING) " + "AS anon_1 FROM mytable", + checkpositional=("foo", 1, 5), + dialect="default_qmark", ) def test_funcfilter_windowing_rows(self):