--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 8827
+
+ Fixed a series of issues regarding positionally rendered bound parameters,
+ such as those used for SQLite, asyncpg, MySQL and others. Some compiled
+ forms would not maintain the order of parameters correctly, such as the
+ PostgreSQL ``regexp_replace()`` function as well as within the "nesting"
+ feature of the :class:`.CTE` construct first introduced in :ticket:`4123`.
def visit_function(self, func, **kw):
text = super(OracleCompiler, self).visit_function(func, **kw)
if kw.get("asfrom", False):
- text = "TABLE (%s)" % func
+ text = "TABLE (%s)" % text
return text
def visit_table_valued_column(self, element, **kw):
self.process(binary.right),
)
- def _get_regexp_args(self, binary, kw):
+ def visit_regexp_match_op_binary(self, binary, operator, **kw):
string = self.process(binary.left, **kw)
pattern = self.process(binary.right, **kw)
flags = binary.modifiers["flags"]
- if flags is not None:
- flags = self.process(flags, **kw)
- return string, pattern, flags
-
- def visit_regexp_match_op_binary(self, binary, operator, **kw):
- string, pattern, flags = self._get_regexp_args(binary, kw)
if flags is None:
return "REGEXP_LIKE(%s, %s)" % (string, pattern)
else:
- return "REGEXP_LIKE(%s, %s, %s)" % (string, pattern, flags)
+ return "REGEXP_LIKE(%s, %s, %s)" % (
+ string,
+ pattern,
+ self.process(flags, **kw),
+ )
def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_regexp_match_op_binary(
)
def visit_regexp_replace_op_binary(self, binary, operator, **kw):
- string, pattern, flags = self._get_regexp_args(binary, kw)
+ string = self.process(binary.left, **kw)
+ pattern = self.process(binary.right, **kw)
replacement = self.process(binary.modifiers["replacement"], **kw)
+ flags = binary.modifiers["flags"]
if flags is None:
return "REGEXP_REPLACE(%s, %s, %s)" % (
string,
string,
pattern,
replacement,
- flags,
+ self.process(flags, **kw),
)
return self._generate_generic_binary(
binary, " %s* " % base_op, **kw
)
- flags = self.process(flags, **kw)
- string = self.process(binary.left, **kw)
- pattern = self.process(binary.right, **kw)
return "%s %s CONCAT('(?', %s, ')', %s)" % (
- string,
+ self.process(binary.left, **kw),
base_op,
- flags,
- pattern,
+ self.process(flags, **kw),
+ self.process(binary.right, **kw),
)
def visit_regexp_match_op_binary(self, binary, operator, **kw):
string = self.process(binary.left, **kw)
pattern = self.process(binary.right, **kw)
flags = binary.modifiers["flags"]
- if flags is not None:
- flags = self.process(flags, **kw)
replacement = self.process(binary.modifiers["replacement"], **kw)
if flags is None:
return "REGEXP_REPLACE(%s, %s, %s)" % (
string,
pattern,
replacement,
- flags,
+ self.process(flags, **kw),
)
def visit_empty_set_expr(self, element_types):
"named": ":%(name)s",
}
-_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\]]")
-_BIND_TRANSLATE_CHARS = dict(zip("%():[]", "PAZC__"))
+_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
+_BIND_TRANSLATE_CHARS = dict(zip("%():[] ", "PAZC___"))
OPERATORS = {
# binary
debugging use cases.
"""
+ positiontup_level = None
inline = False
# true if the paramstyle is positional
self.positional = dialect.positional
if self.positional:
+ self.positiontup_level = {}
self.positiontup = []
self._numeric_binds = dialect.paramstyle == "numeric"
self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
self.ctes_recursive = False
if self.positional:
self.cte_positional = {}
+ self.cte_level = {}
+ self.cte_order = collections.defaultdict(list)
@contextlib.contextmanager
def _nested_result(self):
text = self.process(taf.element, **kw)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
)
def visit_over(self, over, **kwargs):
+ text = over.element._compiler_dispatch(self, **kwargs)
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
over.range_, **kwargs
range_ = None
return "%s OVER (%s)" % (
- over.element._compiler_dispatch(self, **kwargs),
+ text,
" ".join(
[
"%s BY %s"
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ text
)
positional_names.append(name)
else:
self.positiontup.append(name)
- elif not escaped_from:
+ self.positiontup_level[name] = len(self.stack)
+ if not escaped_from:
if _BIND_TRANSLATE_RE.search(name):
# not quite the translate use case as we want to
]
}
)
+ if self.positional:
+ self.cte_level[cte] = cte_level
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
- if self.ctes:
- # In compound query, CTEs are shared at the compound level
- if not is_embedded_select:
- nesting_level = len(self.stack) if not toplevel else None
- text = (
- self._render_cte_clause(nesting_level=nesting_level) + text
+ # In compound query, CTEs are shared at the compound level
+ if self.ctes and (not is_embedded_select or toplevel):
+ nesting_level = len(self.stack) if not toplevel else None
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kwargs.get("visiting_cte"),
)
+ + text
+ )
if select_stmt._suffixes:
text += " " + self._generate_prefixes(
self,
nesting_level=None,
include_following_stack=False,
+ visiting_cte=None,
):
"""
include_following_stack
if not ctes:
return ""
-
ctes_recursive = any([cte.recursive for cte in ctes])
if self.positional:
- self.positiontup = (
- sum([self.cte_positional[cte] for cte in ctes], [])
- + self.positiontup
- )
+ self.cte_order[visiting_cte].extend(ctes)
+
+ if visiting_cte is None and self.cte_order:
+ assert self.positiontup is not None
+
+ def get_nested_positional(cte):
+ if cte in self.cte_order:
+ children = self.cte_order.pop(cte)
+ to_add = list(
+ itertools.chain.from_iterable(
+ get_nested_positional(child_cte)
+ for child_cte in children
+ )
+ )
+ if cte in self.cte_positional:
+ return reorder_positional(
+ self.cte_positional[cte],
+ to_add,
+ self.cte_level[children[0]],
+ )
+ else:
+ return to_add
+ else:
+ return self.cte_positional.get(cte, [])
+
+ def reorder_positional(pos, to_add, level):
+ if not level:
+ return to_add + pos
+ index = 0
+ for index, name in enumerate(reversed(pos)):
+ if self.positiontup_level[name] < level: # type: ignore[index] # noqa: E501
+ break
+ return pos[:-index] + to_add + pos[-index:]
+
+ to_add = get_nested_positional(None)
+ self.positiontup = reorder_positional(
+ self.positiontup, to_add, nesting_level
+ )
+
cte_text = self.get_cte_preamble(ctes_recursive) + " "
cte_text += ", \n".join([txt for txt in ctes.values()])
cte_text += "\n "
self._render_cte_clause(
nesting_level=nesting_level,
include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
),
select_text,
)
nesting_level = len(self.stack) if not toplevel else None
text = (
self._render_cte_clause(
- nesting_level=nesting_level, include_following_stack=True
+ nesting_level=nesting_level,
+ include_following_stack=True,
+ visiting_cte=kw.get("visiting_cte"),
)
+ text
)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
if self.ctes:
nesting_level = len(self.stack) if not toplevel else None
- text = self._render_cte_clause(nesting_level=nesting_level) + text
+ text = (
+ self._render_cte_clause(
+ nesting_level=nesting_level,
+ visiting_cte=kw.get("visiting_cte"),
+ )
+ + text
+ )
self.stack.pop(-1)
from __future__ import absolute_import
+from collections import defaultdict
import contextlib
+from copy import copy
import re
import sys
import warnings
render_schema_translate=False,
default_schema_name=None,
from_linting=False,
+ check_param_order=True,
):
if use_default_dialect:
dialect = default.DefaultDialect()
if dialect is None:
dialect = config.db.dialect
- elif dialect == "default":
- dialect = default.DefaultDialect()
+ elif dialect == "default" or dialect == "default_qmark":
+ if dialect == "default":
+ dialect = default.DefaultDialect()
+ else:
+ dialect = default.DefaultDialect(paramstyle="qmark")
dialect.supports_default_values = supports_default_values
dialect.supports_default_metavalue = supports_default_metavalue
elif dialect == "default_enhanced":
if checkparams is not None:
eq_(c.construct_params(params), checkparams)
if checkpositional is not None:
- p = c.construct_params(params)
+ p = c.construct_params(params, escape_names=False)
eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
if check_prefetch is not None:
eq_(c.prefetch, check_prefetch)
},
check_post_param,
)
+ if check_param_order and getattr(c, "params", None):
+
+ def get_dialect(paramstyle, positional):
+ cp = copy(dialect)
+ cp.paramstyle = paramstyle
+ cp.positional = positional
+ return cp
+
+ pyformat_dialect = get_dialect("pyformat", False)
+ pyformat_c = clause.compile(dialect=pyformat_dialect, **kw)
+ stmt = re.sub(r"[\n\t]", "", pyformat_c.string)
+
+ qmark_dialect = get_dialect("qmark", True)
+ qmark_c = clause.compile(dialect=qmark_dialect, **kw)
+ values = list(qmark_c.positiontup)
+ escaped = qmark_c.escaped_bind_names
+
+ for post_param in (
+ qmark_c.post_compile_params | qmark_c.literal_execute_params
+ ):
+ name = qmark_c.bind_names[post_param]
+ if name in values:
+ values = [v for v in values if v != name]
+ positions = []
+ pos_by_value = defaultdict(list)
+ for v in values:
+ try:
+ if v in pos_by_value:
+ start = pos_by_value[v][-1]
+ else:
+ start = 0
+ esc = escaped.get(v, v)
+ pos = stmt.index("%%(%s)s" % (esc,), start) + 2
+ positions.append(pos)
+ pos_by_value[v].append(pos)
+ except ValueError:
+ msg = "Expected to find bindparam %r in %r" % (v, stmt)
+ assert False, msg
+
+ ordered = all(
+ positions[i - 1] < positions[i]
+ for i in range(1, len(positions))
+ )
+
+ expected = [v for _, v in sorted(zip(positions, values))]
+
+ msg = (
+ "Order of parameters %s does not match the order "
+ "in the statement %s. Statement %r" % (values, expected, stmt)
+ )
+
+ is_true(ordered, msg)
class ComparesTables(object):
self.table.c.myid.regexp_replace(
"pattern", "replacement", flags="ig"
),
- "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_3, :myid_2)",
+ "REGEXP_REPLACE(mytable.myid, :myid_1, :myid_2, :myid_3)",
checkparams={
"myid_1": "pattern",
- "myid_3": "replacement",
- "myid_2": "ig",
+ "myid_2": "replacement",
+ "myid_3": "ig",
},
)
self.table.c.myid.regexp_replace(
"pattern", "replacement", flags="ig"
),
- "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_3)s, %(myid_2)s)",
+ "REGEXP_REPLACE(mytable.myid, %(myid_1)s, %(myid_2)s, %(myid_3)s)",
checkparams={
"myid_1": "pattern",
- "myid_3": "replacement",
- "myid_2": "ig",
+ "myid_2": "replacement",
+ "myid_3": "ig",
},
)
stmt, expected, literal_binds=True, params=params
)
+ standalone_escape = testing.combinations(
+ ("normalname", "normalname"),
+ ("_name", "_name"),
+ ("[BracketsAndCase]", "_BracketsAndCase_"),
+ ("has spaces", "has_spaces"),
+ argnames="paramname, expected",
+ )
+
+ @standalone_escape
+ @testing.variation("use_positional", [True, False])
+ def test_standalone_bindparam_escape(
+ self, paramname, expected, use_positional
+ ):
+ stmt = select(table1.c.myid).where(
+ table1.c.name == bindparam(paramname, value="x")
+ )
+ if use_positional:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = ?",
+ params={paramname: "y"},
+ checkpositional=("y",),
+ dialect="sqlite",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name = :%s"
+ % (expected,),
+ params={paramname: "y"},
+ checkparams={expected: "y"},
+ dialect="default",
+ )
+
+ @standalone_escape
+ @testing.variation("use_assert_compile", [True, False])
+ @testing.variation("use_positional", [True, False])
+ def test_standalone_bindparam_escape_expanding(
+ self, paramname, expected, use_assert_compile, use_positional
+ ):
+ stmt = select(table1.c.myid).where(
+ table1.c.name.in_(bindparam(paramname, value=["a", "b"]))
+ )
+ if use_assert_compile:
+ if use_positional:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable "
+ "WHERE mytable.name IN (?, ?)",
+ params={paramname: ["y", "z"]},
+ # NOTE: this is what render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+ checkpositional=(["y", "z"], ["y", "z"]),
+ dialect="sqlite",
+ render_postcompile=True,
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "SELECT mytable.myid FROM mytable WHERE mytable.name IN "
+ "(:%s_1, :%s_2)" % (expected, expected),
+ params={paramname: ["y", "z"]},
+ # NOTE: this is what render_postcompile will do right now
+ # if you run construct_params(). render_postcompile mode
+ # is not actually used by the execution internals, it's for
+ # user-facing compilation code. So this is likely a
+ # current limitation of construct_params() which is not
+ # doing the full blown postcompile; just assert that's
+ # what it does for now. it likely should be corrected
+ # to make more sense.
+ checkparams={
+ "%s_1" % expected: ["y", "z"],
+ "%s_2" % expected: ["y", "z"],
+ },
+ dialect="default",
+ render_postcompile=True,
+ )
+ else:
+ # this is what DefaultDialect actually does.
+ # this should be matched to DefaultDialect._init_compiled()
+ if use_positional:
+ compiled = stmt.compile(
+ dialect=default.DefaultDialect(paramstyle="qmark")
+ )
+ else:
+ compiled = stmt.compile(dialect=default.DefaultDialect())
+ checkparams = compiled.construct_params(
+ {paramname: ["y", "z"]}, escape_names=False
+ )
+ # nothing actually happened. if the compiler had
+ # render_postcompile set, the
+ # above weird param thing happens
+ eq_(checkparams, {paramname: ["y", "z"]})
+ expanded_state = compiled._process_parameters_for_postcompile(
+ checkparams
+ )
+ eq_(
+ expanded_state.additional_parameters,
+ {"%s_1" % (expected,): "y", "%s_2" % (expected,): "z"},
+ )
+ if use_positional:
+ eq_(
+ expanded_state.positiontup,
+ ["%s_1" % (expected,), "%s_2" % (expected,)],
+ )
+
class UnsupportedTest(fixtures.TestBase):
def test_unsupported_element_str_visit_name(self):
") SELECT cte.outer_cte FROM cte",
)
- def test_nesting_cte_in_recursive_cte(self):
+ @testing.fixture
+ def nesting_cte_in_recursive_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True
)
"rec_cte", recursive=True
)
rec_part = select(rec_cte.c.outer_cte).where(
- rec_cte.c.outer_cte == literal(1)
+ rec_cte.c.outer_cte == literal(42)
)
rec_cte = rec_cte.union(rec_part)
stmt = select(rec_cte)
+ return stmt
+
+ def test_nesting_cte_in_recursive_cte_positional(
+ self, nesting_cte_in_recursive_cte
+ ):
self.assert_compile(
- stmt,
+ nesting_cte_in_recursive_cte,
+ "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+ "(SELECT ? AS inner_cte) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = ?) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
+ checkpositional=(1, 42),
+ dialect="default_qmark",
+ )
+
+ def test_nesting_cte_in_recursive_cte(self, nesting_cte_in_recursive_cte):
+ self.assert_compile(
+ nesting_cte_in_recursive_cte,
+ "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+ "(SELECT :param_1 AS inner_cte) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = :param_2) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
+ checkparams={"param_1": 1, "param_2": 42},
+ )
+
+ @testing.fixture
+ def nesting_cte_in_recursive_cte_w_add_cte(self):
+ nesting_cte = select(literal(1).label("inner_cte")).cte(
+ "nesting", nesting=True
+ )
+
+ rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
+ "rec_cte", recursive=True
+ )
+ rec_part = select(rec_cte.c.outer_cte).where(
+ rec_cte.c.outer_cte == literal(42)
+ )
+ rec_cte = rec_cte.union(rec_part)
+
+ stmt = select(rec_cte)
+ return stmt
+
+ def test_nesting_cte_in_recursive_cte_w_add_cte_positional(
+ self, nesting_cte_in_recursive_cte_w_add_cte
+ ):
+ self.assert_compile(
+ nesting_cte_in_recursive_cte_w_add_cte,
+ "WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
+ "(SELECT ? AS inner_cte) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = ?) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
+ checkpositional=(1, 42),
+ dialect="default_qmark",
+ )
+
+ def test_nesting_cte_in_recursive_cte_w_add_cte(
+ self, nesting_cte_in_recursive_cte_w_add_cte
+ ):
+ self.assert_compile(
+ nesting_cte_in_recursive_cte_w_add_cte,
"WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
"(SELECT :param_1 AS inner_cte) "
"SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
"SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
"WHERE rec_cte.outer_cte = :param_2) "
"SELECT rec_cte.outer_cte FROM rec_cte",
+ checkparams={"param_1": 1, "param_2": 42},
)
def test_recursive_nesting_cte_in_cte(self):
"SELECT cte.outer_cte FROM cte",
)
- def test_same_nested_cte_is_not_generated_twice(self):
+ @testing.fixture
+ def same_nested_cte_is_not_generated_twice(self):
# Same = name and query
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
"nesting_cte", nesting=True
)
select_add_cte = select(
- (nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
+ (nesting_cte_used_twice.c.inner_cte_1 + 2).label("next_value")
).cte("nesting_2", nesting=True)
union_cte = (
select(
- (nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
+ (nesting_cte_used_twice.c.inner_cte_1 - 3).label("next_value")
)
.union(select(select_add_cte))
.cte("wrapper", nesting=True)
.add_cte(nesting_cte_used_twice)
.union(select(nesting_cte_used_twice))
)
+ return stmt
+ def test_same_nested_cte_is_not_generated_twice_positional(
+ self, same_nested_cte_is_not_generated_twice
+ ):
self.assert_compile(
- stmt,
+ same_nested_cte_is_not_generated_twice,
+ "WITH nesting_cte AS "
+ "(SELECT ? AS inner_cte_1)"
+ ", wrapper AS "
+ "(WITH nesting_2 AS "
+ "(SELECT nesting_cte.inner_cte_1 + ? "
+ "AS next_value "
+ "FROM nesting_cte)"
+ " SELECT nesting_cte.inner_cte_1 - ? "
+ "AS next_value "
+ "FROM nesting_cte UNION SELECT nesting_2.next_value "
+ "AS next_value FROM nesting_2)"
+ " SELECT wrapper.next_value "
+ "FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
+ "FROM nesting_cte",
+ checkpositional=(1, 2, 3),
+ dialect="default_qmark",
+ )
+
+ def test_same_nested_cte_is_not_generated_twice(
+ self, same_nested_cte_is_not_generated_twice
+ ):
+ self.assert_compile(
+ same_nested_cte_is_not_generated_twice,
"WITH nesting_cte AS "
"(SELECT :param_1 AS inner_cte_1)"
", wrapper AS "
"FROM nesting_cte)"
" SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
"AS next_value "
- "FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
- "FROM nesting_2)"
+ "FROM nesting_cte UNION SELECT nesting_2.next_value "
+ "AS next_value FROM nesting_2)"
" SELECT wrapper.next_value "
"FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
"FROM nesting_cte",
+ checkparams={
+ "param_1": 1,
+ "inner_cte_1_2": 2,
+ "inner_cte_1_1": 3,
+ },
)
- def test_recursive_nesting_cte_in_recursive_cte(self):
+ @testing.fixture
+ def recursive_nesting_cte_in_recursive_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True, recursive=True
)
nesting_rec_part = select(nesting_cte.c.inner_cte).where(
- nesting_cte.c.inner_cte == literal(1)
+ nesting_cte.c.inner_cte == literal(2)
)
nesting_cte = nesting_cte.union(nesting_rec_part)
"rec_cte", recursive=True
)
rec_part = select(rec_cte.c.outer_cte).where(
- rec_cte.c.outer_cte == literal(1)
+ rec_cte.c.outer_cte == literal(3)
)
rec_cte = rec_cte.union(rec_part)
stmt = select(rec_cte)
+ return stmt
+
+ def test_recursive_nesting_cte_in_recursive_cte_positional(
+ self, recursive_nesting_cte_in_recursive_cte
+ ):
self.assert_compile(
- stmt,
+ recursive_nesting_cte_in_recursive_cte,
+ "WITH RECURSIVE rec_cte(outer_cte) AS ("
+ "WITH RECURSIVE nesting(inner_cte) AS "
+ "(SELECT ? AS inner_cte UNION "
+ "SELECT nesting.inner_cte AS inner_cte FROM nesting "
+ "WHERE nesting.inner_cte = ?) "
+ "SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
+ "SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
+ "WHERE rec_cte.outer_cte = ?) "
+ "SELECT rec_cte.outer_cte FROM rec_cte",
+ checkpositional=(1, 2, 3),
+ dialect="default_qmark",
+ )
+
+ def test_recursive_nesting_cte_in_recursive_cte(
+ self, recursive_nesting_cte_in_recursive_cte
+ ):
+ self.assert_compile(
+ recursive_nesting_cte_in_recursive_cte,
"WITH RECURSIVE rec_cte(outer_cte) AS ("
"WITH RECURSIVE nesting(inner_cte) AS "
"(SELECT :param_1 AS inner_cte UNION "
"SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
"WHERE rec_cte.outer_cte = :param_3) "
"SELECT rec_cte.outer_cte FROM rec_cte",
+ checkparams={"param_1": 1, "param_2": 2, "param_3": 3},
)
def test_select_from_insert_cte_with_nesting(self):
") SELECT cte.outer_cte FROM cte",
)
- def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
+ @testing.fixture
+ def cte_in_compound_select(self):
+ upper = select(literal(1).label("z"))
+
+ lower_a_cte = select(literal(2).label("x")).cte("xx", nesting=True)
+ lower_a = select(literal(3).label("y")).add_cte(lower_a_cte)
+ lower_b = select(literal(4).label("w"))
+
+ stmt = upper.union_all(lower_a.union_all(lower_b))
+ return stmt
+
+ def test_cte_in_compound_select_positional(self, cte_in_compound_select):
+ self.assert_compile(
+ cte_in_compound_select,
+ "SELECT ? AS z UNION ALL (WITH xx AS "
+ "(SELECT ? AS x) "
+ "SELECT ? AS y UNION ALL SELECT ? AS w)",
+ checkpositional=(1, 2, 3, 4),
+ dialect="default_qmark",
+ )
+
+ def test_cte_in_compound_select(self, cte_in_compound_select):
+ self.assert_compile(
+ cte_in_compound_select,
+ "SELECT :param_1 AS z UNION ALL (WITH xx AS "
+ "(SELECT :param_2 AS x) "
+ "SELECT :param_3 AS y UNION ALL SELECT :param_4 AS w)",
+ checkparams={
+ "param_1": 1,
+ "param_2": 2,
+ "param_3": 3,
+ "param_4": 4,
+ },
+ )
+
+ @testing.fixture
+ def recursive_cte_referenced_multiple_times_with_nesting_cte(self):
rec_root = select(literal(1).label("the_value")).cte(
"recursive_cte", recursive=True
)
exists(
select(rec_root_ref.c.the_value)
.where(rec_root_ref.c.the_value < 10)
- .limit(1)
+ .limit(5)
).label("val")
).cte("should_continue", nesting=True)
rec_cte = rec_root.union_all(rec_part)
stmt = rec_cte.select()
+ return stmt
+ def test_recursive_cte_referenced_multiple_times_with_nesting_cte_pos(
+ self, recursive_cte_referenced_multiple_times_with_nesting_cte
+ ):
self.assert_compile(
- stmt,
+ recursive_cte_referenced_multiple_times_with_nesting_cte,
+ "WITH RECURSIVE recursive_cte(the_value) AS ("
+ "SELECT ? AS the_value UNION ALL ("
+ "WITH allow_multiple_ref AS ("
+ "SELECT recursive_cte.the_value AS the_value "
+ "FROM recursive_cte)"
+ ", should_continue AS (SELECT EXISTS ("
+ "SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
+ " WHERE allow_multiple_ref.the_value < ?"
+ " LIMIT ?) AS val) "
+ "SELECT allow_multiple_ref.the_value * ? AS anon_1"
+ " FROM allow_multiple_ref, should_continue "
+ "WHERE should_continue.val != 1"
+ " UNION ALL SELECT allow_multiple_ref.the_value * ?"
+ " AS anon_2 FROM allow_multiple_ref, should_continue"
+ " WHERE should_continue.val != 1))"
+ " SELECT recursive_cte.the_value FROM recursive_cte",
+ checkpositional=(1, 10, 5, 2, 3),
+ dialect="default_qmark",
+ )
+
+ def test_recursive_cte_referenced_multiple_times_with_nesting_cte(
+ self, recursive_cte_referenced_multiple_times_with_nesting_cte
+ ):
+ self.assert_compile(
+ recursive_cte_referenced_multiple_times_with_nesting_cte,
"WITH RECURSIVE recursive_cte(the_value) AS ("
"SELECT :param_1 AS the_value UNION ALL ("
"WITH allow_multiple_ref AS ("
- "SELECT recursive_cte.the_value AS the_value FROM recursive_cte)"
+ "SELECT recursive_cte.the_value AS the_value "
+ "FROM recursive_cte)"
", should_continue AS (SELECT EXISTS ("
"SELECT allow_multiple_ref.the_value FROM allow_multiple_ref"
" WHERE allow_multiple_ref.the_value < :the_value_2"
" AS anon_2 FROM allow_multiple_ref, should_continue"
" WHERE should_continue.val != true))"
" SELECT recursive_cte.the_value FROM recursive_cte",
+ checkparams={
+ "param_1": 1,
+ "param_2": 5,
+ "the_value_2": 10,
+ "the_value_1": 2,
+ "the_value_3": 3,
+ },
)
("name with~~tildes~~",),
argnames="name",
)
- def test_bindparam_key_proc_for_copies(self, meth, name):
+ @testing.combinations(True, False, argnames="positional")
+ def test_bindparam_key_proc_for_copies(self, meth, name, positional):
r"""test :ticket:`6249`.
Revised for :ticket:`8056`.
token = re.sub(r"[%\(\) \$\[\]]", "_", name)
- self.assert_compile(
- expr,
- '"%(name)s" IN (:%(token)s_1_1, '
- ":%(token)s_1_2, :%(token)s_1_3)" % {"name": name, "token": token},
- render_postcompile=True,
- dialect="default",
- )
+ if positional:
+ self.assert_compile(
+ expr,
+ '"%(name)s" IN (?, ?, ?)' % {"name": name},
+ checkpositional=(1, 2, 3),
+ render_postcompile=True,
+ dialect="default_qmark",
+ )
+ else:
+ tokens = ["%s_1_%s" % (token, i) for i in range(1, 4)]
+ self.assert_compile(
+ expr,
+ '"%(name)s" IN (:%(token)s_1_1, '
+ ":%(token)s_1_2, :%(token)s_1_3)"
+ % {"name": name, "token": token},
+ checkparams=dict(zip(tokens, [1, 2, 3])),
+ render_postcompile=True,
+ dialect="default",
+ )
def test_expanding_in_bindparam_safe_to_clone(self):
expr = column("x").in_([1, 2, 3])
"OVER (PARTITION BY mytable.description RANGE BETWEEN :param_1 "
"FOLLOWING AND :param_2 FOLLOWING) "
"AS anon_1 FROM mytable",
+ checkparams={"name_1": "foo", "param_1": 1, "param_2": 5},
+ )
+
+ def test_funcfilter_windowing_range_positional(self):
+ self.assert_compile(
+ select(
+ func.rank()
+ .filter(table1.c.name > "foo")
+ .over(range_=(1, 5), partition_by=["description"])
+ ),
+ "SELECT rank() FILTER (WHERE mytable.name > ?) "
+ "OVER (PARTITION BY mytable.description RANGE BETWEEN ? "
+ "FOLLOWING AND ? FOLLOWING) "
+ "AS anon_1 FROM mytable",
+ checkpositional=("foo", 1, 5),
+ dialect="default_qmark",
)
def test_funcfilter_windowing_rows(self):