From: Federico Caselli Date: Thu, 16 Oct 2025 18:59:30 +0000 (+0200) Subject: Improve params implementation X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=33264146aef33e8f564ec2b8fc3730d59889bf18;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve params implementation Added new implementation for the :meth:`.Select.params` method and that of similar statements, via a new statement-only :meth:`.ExecutableStatement.params` method which works more efficiently and correctly than the previous implementations available from :class:`.ClauseElement`, by assocating the given parameter dictionary with the statement overall rather than cloning the statement and rewriting its bound parameters. The :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params` methods, when called on an object that does not implement :class:`.ExecutableStatement`, will continue to work the old way of cloning the object, and will emit a deprecation warning. This issue both resolves the architectural / performance concerns of :ticket:`7066` and also provides correct ORM compatibility for functions like :func:`_orm.aliased`, reported by :ticket:`12915`. Fixes: #7066 Change-Id: I6543c7d0f4da3232b3641fb172c24c446f02f52a --- diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index 53d5c90ba1..8dd006798f 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -497,7 +497,6 @@ E.g.:: New Features and Improvements - Core ===================================== - .. _change_10635: ``Row`` now represents individual column types directly without ``Tuple`` @@ -618,6 +617,51 @@ not the database portion:: :ticket:`11234` +.. _change_7066: + +Improved ``params()`` implementation for executable statements +-------------------------------------------------------------- + +The :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params` +methods have been deprecated in favor of a new implementation on executable +statements that provides improved performance and better integration with +ORM-enabled statements. + +Executable statement objects like :class:`_sql.Select`, :class:`_sql.CompoundSelect`, +and :class:`_sql.TextClause` now provide an improved :meth:`_sql.ExecutableStatement.params` +method that avoids a full cloned traversal of the statement tree. Instead, parameters +are stored directly on the statement object and efficiently merged during compilation +and/or cache key traversal. + +The new implementation provides several benefits: + +* **Better performance** - Parameters are stored in a simple dictionary rather than + requiring a full statement tree traversal with cloning +* **Proper caching integration** - Parameters are correctly integrated into SQLAlchemy's + cache key system via ``_generate_cache_key()`` +* **ORM statement compatibility** - Works correctly with ORM-enabled statements, including + ORM entities used with :func:`_orm.aliased`, subqueries, CTEs, etc. + +Use of :meth:`_sql.ExecutableStatement.params` is unchanged, provided the given +object is a statement object such as :func:`_sql.select`:: + + stmt = select(table).where(table.c.data == bindparam("x")) + + # Execute with parameter value + result = connection.execute(stmt.params(x=5)) + + # Can be chained and used in subqueries + stmt2 = stmt.params(x=6).subquery().select() + result = connection.execute(stmt2.params(x=7)) # Uses x=7 + +The deprecated :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params` +methods on non-executable elements like :class:`_sql.ColumnElement` and general +:class:`_sql.ClauseElement` instances will continue to work during the deprecation +period but will emit deprecation warnings. + +:ticket:`7066` + + .. _change_4950: CREATE TABLE AS SELECT Support diff --git a/doc/build/changelog/unreleased_21/7066.rst b/doc/build/changelog/unreleased_21/7066.rst new file mode 100644 index 0000000000..27281b6ba3 --- /dev/null +++ b/doc/build/changelog/unreleased_21/7066.rst @@ -0,0 +1,21 @@ +.. change:: + :tags: change, sql + :tickets: 7066, 12915 + + Added new implementation for the :meth:`.Select.params` method and that of + similar statements, via a new statement-only + :meth:`.ExecutableStatement.params` method which works more efficiently and + correctly than the previous implementations available from + :class:`.ClauseElement`, by assocating the given parameter dictionary with + the statement overall rather than cloning the statement and rewriting its + bound parameters. The :meth:`_sql.ClauseElement.params` and + :meth:`_sql.ClauseElement.unique_params` methods, when called on an object + that does not implement :class:`.ExecutableStatement`, will continue to + work the old way of cloning the object, and will emit a deprecation + warning. This issue both resolves the architectural / performance + concerns of :ticket:`7066` and also provides correct ORM compatibility for + functions like :func:`_orm.aliased`, reported by :ticket:`12915`. + + .. seealso:: + + :ref:`change_7066` diff --git a/doc/build/core/selectable.rst b/doc/build/core/selectable.rst index d7c2b56c8b..38d47aa657 100644 --- a/doc/build/core/selectable.rst +++ b/doc/build/core/selectable.rst @@ -89,6 +89,9 @@ The classes here are generated using the constructors listed at .. autoclass:: Executable :members: +.. autoclass:: ExecutableStatement + :members: + .. autoclass:: Exists :members: diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 0d8bf6f08b..57cb0fc2a9 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1641,13 +1641,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): "compiled_cache", self.engine._compiled_cache ) - compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( - dialect=dialect, - compiled_cache=compiled_cache, - column_keys=keys, - for_executemany=for_executemany, - schema_translate_map=schema_translate_map, - linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + compiled_sql, extracted_params, param_dict, cache_hit = ( + elem._compile_w_cache( + dialect=dialect, + compiled_cache=compiled_cache, + column_keys=keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + ) ) ret = self._execute_context( dialect, @@ -1660,6 +1662,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): elem, extracted_params, cache_hit=cache_hit, + param_dict=param_dict, ) if has_events: self.dispatch.after_execute( diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 367455d5e2..1ffec4092d 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -1329,6 +1329,7 @@ class DefaultExecutionContext(ExecutionContext): invoked_statement: Executable, extracted_parameters: Optional[Sequence[BindParameter[Any]]], cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + param_dict: _CoreSingleExecuteParams | None = None, ) -> ExecutionContext: """Initialize execution context for a Compiled construct.""" @@ -1417,6 +1418,7 @@ class DefaultExecutionContext(ExecutionContext): compiled.construct_params( extracted_parameters=extracted_parameters, escape_names=False, + _collected_params=param_dict, ) ] else: @@ -1426,6 +1428,7 @@ class DefaultExecutionContext(ExecutionContext): escape_names=False, _group_number=grp, extracted_parameters=extracted_parameters, + _collected_params=param_dict, ) for grp, m in enumerate(parameters) ] diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 2dd6e950a7..8f26eb2c5d 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -54,6 +54,7 @@ from ..sql.base import _select_iterables from ..sql.base import CacheableOptions from ..sql.base import CompileState from ..sql.base import Executable +from ..sql.base import ExecutableStatement from ..sql.base import Generative from ..sql.base import Options from ..sql.dml import UpdateBase @@ -981,7 +982,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[Unpack[_Ts]]): _traverse_internals = [ ("_raw_columns", InternalTraversal.dp_clauseelement_list), ("element", InternalTraversal.dp_clauseelement), - ] + Executable._executable_traverse_internals + ] + ExecutableStatement._executable_traverse_internals _cache_key_traversal = _traverse_internals + [ ("_compile_options", InternalTraversal.dp_has_cache_key) diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 1231d166cd..c49c0f6c3a 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -11,6 +11,7 @@ from ._typing import ColumnExpressionArgument as ColumnExpressionArgument from ._typing import NotNullable as NotNullable from ._typing import Nullable as Nullable from .base import Executable as Executable +from .base import ExecutableStatement as ExecutableStatement from .base import SyntaxExtension as SyntaxExtension from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS from .compiler import FROM_LINTING as FROM_LINTING diff --git a/lib/sqlalchemy/sql/_util_cy.py b/lib/sqlalchemy/sql/_util_cy.py index 8d4ef542b9..ddb8680196 100644 --- a/lib/sqlalchemy/sql/_util_cy.py +++ b/lib/sqlalchemy/sql/_util_cy.py @@ -15,6 +15,7 @@ from typing import Union if TYPE_CHECKING: from .cache_key import CacheConst + from ..engine.interfaces import _CoreSingleExecuteParams # START GENERATED CYTHON IMPORT # This section is automatically generated by the script tools/cython_imports.py @@ -67,13 +68,12 @@ class prefix_anon_map(Dict[str, str]): return value +_AM_KEY = Union[int, str, "CacheConst"] +_AM_VALUE = Union[int, Literal[True], "_CoreSingleExecuteParams"] + + @cython.cclass -class anon_map( - Dict[ - Union[int, str, "Literal[CacheConst.NO_CACHE]"], - Union[int, Literal[True]], - ] -): +class anon_map(Dict[_AM_KEY, _AM_VALUE]): """A map that creates new keys for missing key access. Produces an incrementing sequence given a series of unique keys. @@ -96,9 +96,7 @@ class anon_map( @cython.cfunc # type:ignore[misc] @cython.inline # type:ignore[misc] - def _add_missing( - self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], / - ) -> int: + def _add_missing(self: anon_map, key: _AM_KEY, /) -> int: val: int = self._index self._index += 1 self_dict: dict = self # type: ignore[type-arg] @@ -116,11 +114,7 @@ class anon_map( if cython.compiled: - def __getitem__( - self: anon_map, - key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], - /, - ) -> Union[int, Literal[True]]: + def __getitem__(self: anon_map, key: _AM_KEY, /) -> _AM_VALUE: self_dict: dict = self # type: ignore[type-arg] if key in self_dict: @@ -128,7 +122,5 @@ class anon_map( else: return self._add_missing(key) # type:ignore[no-any-return] - def __missing__( - self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], / - ) -> int: + def __missing__(self: anon_map, key: _AM_KEY, /) -> int: return self._add_missing(key) # type:ignore[no-any-return] diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index a07b1204b9..67eb44fc8d 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -56,6 +56,7 @@ from .visitors import InternalTraversal from .. import event from .. import exc from .. import util +from ..util import EMPTY_DICT from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util.typing import Self @@ -92,6 +93,7 @@ if TYPE_CHECKING: from ..engine import Connection from ..engine import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _ExecuteOptions from ..engine.interfaces import _ImmutableExecuteOptions from ..engine.interfaces import CacheStats @@ -1010,7 +1012,7 @@ class CacheableOptions(Options, HasCacheKey): @hybridmethod def _generate_cache_key(self) -> Optional[CacheKey]: - return HasCacheKey._generate_cache_key_for_object(self) + return HasCacheKey._generate_cache_key(self) class ExecutableOption(HasCopyInternals): @@ -1287,8 +1289,11 @@ class Executable(roles.StatementRole): for_executemany: bool = False, schema_translate_map: Optional[SchemaTranslateMapType] = None, **kw: Any, - ) -> Tuple[ - Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ) -> tuple[ + Compiled, + Sequence[BindParameter[Any]] | None, + _CoreSingleExecuteParams | None, + CacheStats, ]: ... def _execute_on_connection( @@ -1540,6 +1545,50 @@ class Executable(roles.StatementRole): return self._execution_options +class ExecutableStatement(Executable): + """Executable subclass that implements a lightweight version of ``params`` + that avoids a full cloned traverse. + + .. versionadded:: 2.1 + + """ + + _params: util.immutabledict[str, Any] = EMPTY_DICT + + _executable_traverse_internals = ( + Executable._executable_traverse_internals + + [("_params", InternalTraversal.dp_params)] + ) + + @_generative + def params( + self, + __optionaldict: _CoreSingleExecuteParams | None = None, + /, + **kwargs: Any, + ) -> Self: + """Return a copy with the provided bindparam values. + + Returns a copy of this Executable with bindparam values set + to the given dictionary:: + + >>> clause = column("x") + bindparam("foo") + >>> print(clause.compile().params) + {'foo': None} + >>> print(clause.params({"foo": 7}).compile().params) + {'foo': 7} + + """ + if __optionaldict: + kwargs.update(__optionaldict) + self._params = ( + util.immutabledict(kwargs) + if not self._params + else self._params | kwargs + ) + return self + + class SchemaEventTarget(event.EventTarget): """Base class for elements that are the targets of :class:`.DDLEvents` events. diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index f44ca26886..e35c201f21 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -13,6 +13,7 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Final from typing import Iterable from typing import Iterator from typing import List @@ -50,9 +51,10 @@ class _CacheKeyTraversalDispatchType(Protocol): class CacheConst(enum.Enum): NO_CACHE = 0 + PARAMS = 1 -NO_CACHE = CacheConst.NO_CACHE +NO_CACHE: Final = CacheConst.NO_CACHE _CacheKeyTraversalType = Union[ @@ -384,21 +386,11 @@ class HasCacheKey: return None else: assert key is not None - return CacheKey(key, bindparams) - - @classmethod - def _generate_cache_key_for_object( - cls, obj: HasCacheKey - ) -> Optional[CacheKey]: - bindparams: List[BindParameter[Any]] = [] - - _anon_map = anon_map() - key = obj._gen_cache_key(_anon_map, bindparams) - if NO_CACHE in _anon_map: - return None - else: - assert key is not None - return CacheKey(key, bindparams) + return CacheKey( + key, + bindparams, + _anon_map.get(CacheConst.PARAMS), # type: ignore[arg-type] + ) class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): @@ -432,6 +424,7 @@ class CacheKey(NamedTuple): key: Tuple[Any, ...] bindparams: Sequence[BindParameter[Any]] + params: _CoreSingleExecuteParams | None # can't set __hash__ attribute because it interferes # with namedtuple @@ -485,8 +478,8 @@ class CacheKey(NamedTuple): @classmethod def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: - ck1 = CacheKey(left, []) - ck2 = CacheKey(right, []) + ck1 = CacheKey(left, [], None) + ck2 = CacheKey(right, [], None) return ck1._diff(ck2) def _whats_different(self, other: CacheKey) -> Iterator[str]: @@ -1053,5 +1046,21 @@ class _CacheKeyTraversal(HasTraversalDispatch): anon_map[NO_CACHE] = True return () + def visit_params( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if obj: + if CacheConst.PARAMS in anon_map: + to_set = anon_map[CacheConst.PARAMS] | obj + else: + to_set = obj + anon_map[CacheConst.PARAMS] = to_set + return () + _cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index c666fab025..ee13cf8da0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -40,6 +40,7 @@ from typing import Callable from typing import cast from typing import ClassVar from typing import Dict +from typing import Final from typing import FrozenSet from typing import Iterable from typing import Iterator @@ -93,6 +94,7 @@ if typing.TYPE_CHECKING: from .base import _AmbiguousTableNameMap from .base import CompileState from .base import Executable + from .base import ExecutableStatement from .cache_key import CacheKey from .ddl import CreateTableAs from .ddl import ExecutableDDLElement @@ -1390,6 +1392,8 @@ class SQLCompiler(Compiled): _positional_pattern = re.compile( f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" ) + _collect_params: Final[bool] + _collected_params: util.immutabledict[str, Any] @classmethod def _init_compiler_cls(cls): @@ -1489,6 +1493,11 @@ class SQLCompiler(Compiled): # dialect.label_length or dialect.max_identifier_length self.truncated_names: Dict[Tuple[str, str], str] = {} self._truncated_counters: Dict[str, int] = {} + if not cache_key: + self._collect_params = True + self._collected_params = util.EMPTY_DICT + else: + self._collect_params = False # type: ignore[misc] Compiled.__init__(self, dialect, statement, **kwargs) @@ -1627,6 +1636,13 @@ class SQLCompiler(Compiled): def _global_attributes(self) -> Dict[Any, Any]: return {} + def _add_to_params(self, item: ExecutableStatement) -> None: + # assumes that this is called before traversing the statement + # so the call happens outer to inner, meaning that existing params + # take precedence + if item._params: + self._collected_params = item._params | self._collected_params + @util.memoized_instancemethod def _init_cte_state(self) -> MutableMapping[CTE, str]: """Initialize collections related to CTEs only if @@ -1874,8 +1890,19 @@ class SQLCompiler(Compiled): _group_number: Optional[int] = None, _check: bool = True, _no_postcompile: bool = False, + _collected_params: _CoreSingleExecuteParams | None = None, ) -> _MutableCoreSingleExecuteParams: """return a dictionary of bind parameter keys and values""" + if _collected_params is not None: + assert not self._collect_params + elif self._collect_params: + _collected_params = self._collected_params + + if _collected_params: + if not params: + params = _collected_params + else: + params = {**_collected_params, **params} if self._render_postcompile and not _no_postcompile: assert self._post_compile_expanded_state is not None @@ -2704,6 +2731,9 @@ class SQLCompiler(Compiled): return text def visit_textclause(self, textclause, add_to_result_map=None, **kw): + if self._collect_params: + self._add_to_params(textclause) + def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -2731,6 +2761,8 @@ class SQLCompiler(Compiled): def visit_textual_select( self, taf, compound_index=None, asfrom=False, **kw ): + if self._collect_params: + self._add_to_params(taf) toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] @@ -3026,6 +3058,8 @@ class SQLCompiler(Compiled): add_to_result_map: Optional[_ResultMapAppender] = None, **kwargs: Any, ) -> str: + if self._collect_params: + self._add_to_params(func) if add_to_result_map is not None: add_to_result_map(func.name, func.name, (func.name,), func.type) @@ -3081,6 +3115,8 @@ class SQLCompiler(Compiled): def visit_compound_select( self, cs, asfrom=False, compound_index=None, **kwargs ): + if self._collect_params: + self._add_to_params(cs) toplevel = not self.stack compile_state = cs._compile_state_factory(cs, self, **kwargs) @@ -4870,6 +4906,8 @@ class SQLCompiler(Compiled): "the translate_select_structure hook for structural " "translations of SELECT objects" ) + if self._collect_params: + self._add_to_params(select_stmt) # initial setup of SELECT. the compile_state_factory may now # be creating a totally different SELECT from the one that was diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 58a8c3c8e8..be8781bb76 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -55,6 +55,7 @@ if typing.TYPE_CHECKING: from .schema import Sequence as Sequence # noqa: F401 from .schema import Table from ..engine.base import Connection + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType from ..engine.interfaces import Dialect @@ -89,8 +90,11 @@ class BaseDDLElement(ClauseElement): for_executemany: bool = False, schema_translate_map: Optional[SchemaTranslateMapType] = None, **kw: Any, - ) -> Tuple[ - Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats + ) -> tuple[ + Compiled, + typing_Sequence[BindParameter[Any]] | None, + _CoreSingleExecuteParams | None, + CacheStats, ]: raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 54e95d20a6..e85e98a6f8 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -48,6 +48,7 @@ from .base import ColumnSet from .base import CompileState from .base import DialectKWArgs from .base import Executable +from .base import ExecutableStatement from .base import Generative from .base import HasCompileState from .base import HasSyntaxExtensions @@ -1249,7 +1250,7 @@ class Insert(ValuesBase, HasSyntaxExtensions[Literal["post_values"]]): ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals + HasCTE._has_ctes_traverse_internals ) @@ -1614,7 +1615,7 @@ class Update( ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals + HasCTE._has_ctes_traverse_internals ) @@ -1815,7 +1816,7 @@ class Delete( ] + HasPrefixes._has_prefixes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals + HasCTE._has_ctes_traverse_internals ) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ddbfd00c69..7a640ccacc 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -58,6 +58,7 @@ from .base import _expand_cloned from .base import _generative from .base import _NoArg from .base import Executable +from .base import ExecutableStatement from .base import Generative from .base import HasMemoized from .base import Immutable @@ -77,6 +78,7 @@ from .visitors import Visitable from .. import exc from .. import inspection from .. import util +from ..util import deprecated from ..util import HasMemoized_ro_memoized_attribute from ..util import TypingOnly from ..util.typing import Self @@ -117,6 +119,7 @@ if typing.TYPE_CHECKING: from ..engine import Connection from ..engine import Dialect from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import CacheStats from ..engine.interfaces import CompiledCacheType from ..engine.interfaces import CoreExecuteOptionsParameter @@ -610,6 +613,12 @@ class ClauseElement( """ return self._replace_params(False, __optionaldict, kwargs) + @deprecated( + "2.1", + "The params() and unique_params() methods on non-statement " + "ClauseElement objects are deprecated; params() is now limited to " + "statement level objects such as select(), insert(), union(), etc. ", + ) def _replace_params( self, unique: bool, @@ -691,8 +700,11 @@ class ClauseElement( for_executemany: bool = False, schema_translate_map: Optional[SchemaTranslateMapType] = None, **kw: Any, - ) -> typing_Tuple[ - Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ) -> tuple[ + Compiled, + Sequence[BindParameter[Any]] | None, + _CoreSingleExecuteParams | None, + CacheStats, ]: elem_cache_key: Optional[CacheKey] @@ -706,7 +718,7 @@ class ClauseElement( if TYPE_CHECKING: assert compiled_cache is not None - cache_key, extracted_params = elem_cache_key + cache_key, extracted_params, param_dict = elem_cache_key key = ( dialect, cache_key, @@ -726,19 +738,26 @@ class ClauseElement( schema_translate_map=schema_translate_map, **kw, ) + # ensure that params of the current statement are not + # left in the cache + assert not compiled_sql._collect_params # type: ignore[attr-defined] # noqa: E501 compiled_cache[key] = compiled_sql else: cache_hit = dialect.CACHE_HIT else: + param_dict = None extracted_params = None compiled_sql = self._compiler( dialect, - cache_key=elem_cache_key, + cache_key=None, column_keys=column_keys, for_executemany=for_executemany, schema_translate_map=schema_translate_map, **kw, ) + # here instead the params need to be extracted, since we don't + # have them otherwise + assert compiled_sql._collect_params # type: ignore[attr-defined] # noqa: E501 if not dialect._supports_statement_cache: cache_hit = dialect.NO_DIALECT_SUPPORT @@ -747,7 +766,7 @@ class ClauseElement( else: cache_hit = dialect.NO_CACHE_KEY - return compiled_sql, extracted_params, cache_hit + return compiled_sql, extracted_params, param_dict, cache_hit def __invert__(self): # undocumented element currently used by the ORM for @@ -2314,7 +2333,7 @@ class TextClause( roles.SelectStatementRole, roles.InElementRole, Generative, - Executable, + ExecutableStatement, DQLDMLClauseElement, roles.BinaryElementRole[Any], inspection.Inspectable["TextClause"], @@ -2343,7 +2362,7 @@ class TextClause( _traverse_internals: _TraverseInternalsType = [ ("_bindparams", InternalTraversal.dp_string_clauseelement_dict), ("text", InternalTraversal.dp_string), - ] + Executable._executable_traverse_internals + ] + ExecutableStatement._executable_traverse_internals _is_text_clause = True diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index d85142ed90..60029dbf4f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -66,6 +66,7 @@ from .base import _from_objects as _from_objects from .base import _select_iterables as _select_iterables from .base import ColumnCollection as ColumnCollection from .base import Executable as Executable +from .base import ExecutableStatement as ExecutableStatement from .cache_key import CacheKey as CacheKey from .dml import Delete as Delete from .dml import Insert as Insert diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index f4437bb980..8671986a6b 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -36,7 +36,7 @@ from . import util as sqlutil from ._typing import is_table_value_type from .base import _entity_namespace from .base import ColumnCollection -from .base import Executable +from .base import ExecutableStatement from .base import Generative from .base import HasMemoized from .elements import _type_from_args @@ -114,7 +114,9 @@ def register_function( reg[identifier] = fn -class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): +class FunctionElement( + ColumnElement[_T], ExecutableStatement, FromClause, Generative +): """Base for SQL function-oriented constructs. This is a `generic type `_, @@ -140,7 +142,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ("clause_expr", InternalTraversal.dp_clauseelement), ("_with_ordinality", InternalTraversal.dp_boolean), ("_table_value_type", InternalTraversal.dp_has_cache_key), - ] + Executable._executable_traverse_internals + ] + ExecutableStatement._executable_traverse_internals packagenames: Tuple[str, ...] = () diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 731f7e5932..02fcd34131 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -37,6 +37,7 @@ from . import schema from . import visitors from .base import _clone from .base import Executable +from .base import ExecutableStatement from .base import Options from .cache_key import CacheConst from .operators import ColumnOperators @@ -499,7 +500,7 @@ class DeferredLambdaElement(LambdaElement): class StatementLambdaElement( - roles.AllowsLambdaRole, LambdaElement, Executable + roles.AllowsLambdaRole, ExecutableStatement, LambdaElement ): """Represent a composable SQL statement as a :class:`_sql.LambdaElement`. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 6e62d30bc4..383f81b432 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -76,6 +76,7 @@ from .base import CompileState from .base import DedupeColumnCollection from .base import DialectKWArgs from .base import Executable +from .base import ExecutableStatement from .base import Generative from .base import HasCompileState from .base import HasMemoized @@ -290,7 +291,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement): raise NotImplementedError() -class ExecutableReturnsRows(Executable, ReturnsRows): +class ExecutableReturnsRows(ExecutableStatement, ReturnsRows): """base for executable statements that return rows.""" @@ -4593,7 +4594,7 @@ class CompoundSelect( + SupportsCloneAnnotations._clone_annotations_traverse_internals + HasCTE._has_ctes_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals ) selects: List[SelectBase] @@ -5484,7 +5485,7 @@ class Select( + HasSuffixes._has_suffixes_traverse_internals + HasHints._has_hints_traverse_internals + SupportsCloneAnnotations._clone_annotations_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals + DialectKWArgs._dialect_kwargs_traverse_internals ) @@ -7251,7 +7252,7 @@ class TextualSelect(SelectBase, ExecutableReturnsRows, Generative): ] + SupportsCloneAnnotations._clone_annotations_traverse_internals + HasCTE._has_ctes_traverse_internals - + Executable._executable_traverse_internals + + ExecutableStatement._executable_traverse_internals ) _is_textual = True diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 38f8e3e162..85dfda2818 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -953,6 +953,11 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): ): return COMPARE_FAILED + def visit_params( + self, attrname, left_parent, left, right_parent, right, **kw + ): + return left == right + def compare_expression_clauselist(self, left, right, **kw): if left.operator is right.operator: if operators.is_associative(left.operator): diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 1cd7097f2d..396b06b21d 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -410,6 +410,9 @@ class InternalTraversal(Enum): """Visit a list of inspectable objects which upon inspection are HasCacheKey objects.""" + dp_params = "PM" + """Visit the _params collection of ExecutableStatement""" + _TraverseInternalsType = List[Tuple[str, InternalTraversal]] """a structure that defines how a HasTraverseInternals should be diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index 3b98d60be7..b2fb34ad5c 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -38,6 +38,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql import and_ from sqlalchemy.sql import sqltypes +from sqlalchemy.sql import visitors from sqlalchemy.sql.selectable import Join as core_join from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -367,6 +368,60 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): checkparams={"param_1": 6, "param_2": 5}, ) + @testing.variation("use_get_params", [True, False]) + def test_annotated_cte_params_traverse(self, use_get_params): + """test #12915 + + Tests the original issue in #12915 which was a specific issue + involving cloned_traverse with Annotated subclasses, where traversal + would not properly cover a CTE's self-referential structure. + + This case still does not work in the general ORM case, so the + implementation of .params() was changed to not rely upon + cloned_traversal. + + """ + User = self.classes.User + ids_param = bindparam("ids") + cte = select(User).where(User.id == ids_param).cte("cte") + ca = cte._annotate({"foo": "bar"}) + stmt = select(ca) + if use_get_params: + stmt = stmt.params(ids=17) + else: + # test without using params(), in case the implementation + # for params() changes we still want to test cloned_traverse + def visit_bindparam(bind): + if bind.key == "ids": + bind.value = 17 + bind.required = False + + stmt = visitors.cloned_traverse( + stmt, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + self.assert_compile( + stmt, + "WITH cte AS (SELECT users.id AS id, users.name AS name " + "FROM users WHERE users.id = :ids) " + "SELECT cte.id, cte.name FROM cte", + checkparams={"ids": 17}, + ) + + def test_orm_cte_with_params(self, connection): + """test for #12915's new implementation""" + User = self.classes.User + ids_param = bindparam("ids") + cte = select(User).where(User.id == ids_param).cte("cte") + stmt = select(aliased(User, cte.alias("a1"), adapt_on_names=True)) + + res = connection.execute(stmt, {"ids": 7}).all() + eq_(res, [(7, "jack")]) + with Session(connection) as s: + res = s.scalars(stmt, {"ids": 7}).all() + eq_(res, [User(id=7, name="jack")]) + class PropagateAttrsTest(QueryTest): __backend__ = True diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index ff7fd782ab..fbc6db81c5 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -2353,29 +2353,37 @@ class TestCacheKeyUtil(fixtures.TestBase): eq_( re.compile(r"[\n\s]+", re.M).sub( " ", - str(CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[])), + str( + CacheKey( + key=((1, (2, 7, 4), 5),), bindparams=[], params={} + ) + ), ), "CacheKey(key=( ( 1, ( 2, 7, 4, ), 5, ), ),)", ) def test_nested_tuple_difference(self): """Test difference detection in nested tuples""" - k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[]) - k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[]) + k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[], params={}) + k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[], params={}) eq_(list(k1._whats_different(k2)), ["key[0][1][1]: 3 != 7"]) def test_deeply_nested_tuple_difference(self): """Test difference detection in deeply nested tuples""" - k1 = CacheKey(key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[]) - k2 = CacheKey(key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[]) + k1 = CacheKey( + key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[], params={} + ) + k2 = CacheKey( + key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[], params={} + ) eq_(list(k1._whats_different(k2)), ["key[0][1][1][1]: 4 != 9"]) def test_multiple_differences_nested(self): """Test detection of multiple differences in nested structure""" - k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[]) - k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[]) + k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[], params={}) + k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[], params={}) eq_( list(k1._whats_different(k2)), @@ -2384,21 +2392,29 @@ class TestCacheKeyUtil(fixtures.TestBase): def test_diff_method(self): """Test the _diff() method that returns a comma-separated string""" - k1 = CacheKey(key=((1, (2, 3)),), bindparams=[]) - k2 = CacheKey(key=((1, (5, 7)),), bindparams=[]) + k1 = CacheKey(key=((1, (2, 3)),), bindparams=[], params={}) + k2 = CacheKey(key=((1, (5, 7)),), bindparams=[], params={}) eq_(k1._diff(k2), "key[0][1][0]: 2 != 5, key[0][1][1]: 3 != 7") def test_with_string_differences(self): """Test detection of string differences""" - k1 = CacheKey(key=(("name", ("x", "value")),), bindparams=[]) - k2 = CacheKey(key=(("name", ("y", "value")),), bindparams=[]) + k1 = CacheKey( + key=(("name", ("x", "value")),), bindparams=[], params={} + ) + k2 = CacheKey( + key=(("name", ("y", "value")),), bindparams=[], params={} + ) eq_(list(k1._whats_different(k2)), ["key[0][1][0]: x != y"]) def test_with_mixed_types(self): """Test detection of differences with mixed types""" - k1 = CacheKey(key=(("id", 1, ("nested", 100)),), bindparams=[]) - k2 = CacheKey(key=(("id", 1, ("nested", 200)),), bindparams=[]) + k1 = CacheKey( + key=(("id", 1, ("nested", 100)),), bindparams=[], params={} + ) + k2 = CacheKey( + key=(("id", 1, ("nested", 200)),), bindparams=[], params={} + ) eq_(list(k1._whats_different(k2)), ["key[0][2][1]: 100 != 200"]) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index d044d8b57f..c866df3c05 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -46,6 +46,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not +from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.schema import eq_clause_element A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None @@ -898,13 +899,17 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): params = {"xb": 42, "yb": 33} sel = select(Y).select_from(jj).params(params) + cc = sel._generate_cache_key() + eq_( [ - eq_clause_element(bindparam("yb", value=33)), - eq_clause_element(bindparam("xb", value=42)), + eq_clause_element(bindparam("yb", None, Integer)), + eq_clause_element(bindparam("xb", None, Integer)), ], - sel._generate_cache_key()[1], + cc[1], ) + eq_(cc[2], {"xb": 42, "yb": 33}) + eq_(sel.compile().params, params) def test_dont_traverse_immutables(self): meta = MetaData() @@ -960,7 +965,12 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): stmt._generate_cache_key()[1], ) - stmt = select(b).where(criteria.params({param_key: "some other data"})) + with expect_deprecated( + r"The params\(\) and unique_params\(\) methods on non-statement" + ): + stmt = select(b).where( + criteria.params({param_key: "some other data"}) + ) self.assert_compile( stmt, "SELECT b.id, b.data FROM b, (SELECT b.id AS id " @@ -1004,7 +1014,11 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): Ps, join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)), and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p), - ).params(params) + ) + with expect_deprecated( + r"The params\(\) and unique_params\(\) methods on non-statement" + ): + jj = jj.params(params) eq_( [ @@ -1039,11 +1053,13 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s") s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s") - jj = ( - join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p)) - .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p)) - .params(params) + jj = join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p)).join( + s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p) ) + with expect_deprecated( + r"The params\(\) and unique_params\(\) methods on non-statement" + ): + jj = jj.params(params) eq_( [ diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index c64aec7064..1aee4300ab 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -59,6 +59,7 @@ from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertions import expect_raises_message from sqlalchemy.testing.assertions import expect_warnings from sqlalchemy.testing.engines import all_dialects @@ -691,12 +692,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE users.id > anon_1.z", ) - s = select(users).where( - users.c.id.between( - calculate.alias("c1").unique_params(x=17, y=45).c.z, - calculate.alias("c2").unique_params(x=5, y=12).c.z, - ), - ) + with expect_deprecated( + r"The params\(\) and unique_params\(\) methods on non-statement" + ): + s = select(users).where( + users.c.id.between( + calculate.alias("c1").unique_params(x=17, y=45).c.z, + calculate.alias("c2").unique_params(x=5, y=12).c.z, + ), + ) self.assert_compile( s, diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index c4c28b4f46..191ad66369 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -3147,7 +3147,7 @@ class AnnotationsTest(fixtures.TestBase): """ user = Table("user", MetaData(), Column("id", Integer)) - ids_param = bindparam("ids") + ids_param = bindparam("ids", -1) cte = select(user).where(user.c.id == ids_param).cte("cte") @@ -3157,6 +3157,8 @@ class AnnotationsTest(fixtures.TestBase): if use_get_params: stmt = stmt.params(ids=17) + exp = -1 + eq_(stmt._generate_cache_key()[2], {"ids": 17}) else: # test without using params(), as the implementation # for params() will be changing @@ -3170,13 +3172,16 @@ class AnnotationsTest(fixtures.TestBase): {"maintain_key": True, "detect_subquery_cols": True}, {"bindparam": visit_bindparam}, ) + exp = 17 + eq_(stmt._generate_cache_key()[2], None) eq_( stmt.selected_columns.id.table.element._where_criteria[ 0 ].right.value, - 17, + exp, ) + eq_(stmt.compile().params, {"ids": 17}) def test_basic_attrs(self): t = Table( diff --git a/test/sql/test_statement_params.py b/test/sql/test_statement_params.py new file mode 100644 index 0000000000..d4f2bf4a18 --- /dev/null +++ b/test/sql/test_statement_params.py @@ -0,0 +1,199 @@ +import random + +from sqlalchemy import testing +from sqlalchemy.schema import Column +from sqlalchemy.sql import bindparam +from sqlalchemy.sql import column +from sqlalchemy.sql import dml +from sqlalchemy.sql import func +from sqlalchemy.sql import select +from sqlalchemy.sql import text +from sqlalchemy.sql.base import ExecutableStatement +from sqlalchemy.sql.elements import literal +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_deprecated +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_not +from sqlalchemy.testing import ne_ +from sqlalchemy.testing.schema import Table +from sqlalchemy.types import Integer +from sqlalchemy.types import Text +from sqlalchemy.util.langhelpers import class_hierarchy + + +class BasicTests(fixtures.TestBase): + def _all_subclasses(self, cls_): + return dict.fromkeys( + s + for s in class_hierarchy(cls_) + # class_hierarchy may return values that + # aren't subclasses of cls + if issubclass(s, cls_) + ) + + @staticmethod + def _relevant_impls(): + return ( + text("select 1 + 2"), + text("select 42 as q").columns(column("q", Integer)), + func.max(42), + select(1, 2).union(select(3, 4)), + select(1, 2), + ) + + def test_params_impl(self): + exclude = (dml.UpdateBase,) + visit_names = set() + for cls_ in self._all_subclasses(ExecutableStatement): + if not issubclass(cls_, exclude): + if "__visit_name__" in cls_.__dict__: + visit_names.add(cls_.__visit_name__) + eq_(cls_.params, ExecutableStatement.params, cls_) + else: + ne_(cls_.params, ExecutableStatement.params, cls_) + for other in exclude: + if issubclass(cls_, other): + eq_(cls_.params, other.params, cls_) + break + else: + assert False + + extra = {"orm_from_statement"} + eq_( + visit_names - extra, + {i.__visit_name__ for i in self._relevant_impls()}, + ) + + @testing.combinations(*_relevant_impls()) + def test_compile_params(self, impl): + new = impl.params(foo=5, bar=10) + is_not(new, impl) + eq_(impl.compile()._collected_params, {}) + eq_(new.compile()._collected_params, {"foo": 5, "bar": 10}) + eq_(new._generate_cache_key()[2], {"foo": 5, "bar": 10}) + + +class CacheTests(fixtures.TablesTest): + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table("a", metadata, Column("data", Integer)) + Table("b", metadata, Column("data", Text)) + + @classmethod + def insert_data(cls, connection): + connection.execute( + cls.tables.a.insert(), + [{"data": i} for i in range(1, 11)], + ) + connection.execute( + cls.tables.b.insert(), + [{"data": "row %d" % i} for i in range(1, 11)], + ) + + def test_plain_select(self, connection): + a = self.tables.a + + cs = connection.scalars + + for _ in range(3): + x1 = random.randint(1, 10) + + eq_(cs(select(a).where(a.c.data == x1)).all(), [x1]) + stmt = select(a).where(a.c.data == bindparam("x", x1)) + eq_(cs(stmt).all(), [x1]) + + x1 = random.randint(1, 10) + eq_(cs(stmt.params({"x": x1})).all(), [x1]) + + x1 = random.randint(1, 10) + eq_(cs(stmt, {"x": x1}).all(), [x1]) + + x1 = random.randint(1, 10) + x2 = random.randint(1, 10) + eq_(cs(stmt.params({"x": x1}), {"x": x2}).all(), [x2]) + + stmt2 = stmt.params(x=6).subquery().select() + eq_(cs(stmt2).all(), [6]) + eq_(cs(stmt2.params({"x": 2})).all(), [2]) + + with expect_deprecated( + r"The params\(\) and unique_params\(\) " + "methods on non-statement" + ): + # NOTE: can't mix and match the two params styles here + stmt3 = stmt.params(x=6).subquery().params(x=8).select() + eq_(cs(stmt3).all(), [6]) + eq_(cs(stmt3.params({"x": 9})).all(), [9]) + + def test_union(self, connection): + a = self.tables.a + + cs = connection.scalars + for _ in range(3): + x1 = random.randint(1, 10) + x2 = random.randint(1, 10) + + eq_( + cs( + select(a) + .where(a.c.data == x1) + .union_all(select(a).where(a.c.data == x2)) + .order_by(a.c.data) + ).all(), + sorted([x1, x2]), + ) + + x1 = random.randint(1, 10) + x2 = random.randint(1, 10) + stmt = ( + select(a, literal(1).label("ord")) + .where(a.c.data == bindparam("x", x1)) + .union_all( + select(a, literal(2)).where(a.c.data == bindparam("y", x2)) + ) + .order_by("ord") + ) + eq_(cs(stmt).all(), [x1, x2]) + + x1a = random.randint(1, 10) + eq_(cs(stmt.params({"x": x1a})).all(), [x1a, x2]) + + x2 = random.randint(1, 10) + eq_(cs(stmt, {"y": x2}).all(), [x1, x2]) + + x1 = random.randint(1, 10) + x2 = random.randint(1, 10) + eq_(cs(stmt.params({"x": x1}), {"y": x2}).all(), [x1, x2]) + + x1 = random.randint(1, 10) + x2 = random.randint(1, 10) + stmt2 = ( + stmt.params(x=x1) + .subquery() + .select() + .params(y=x2) + .order_by("ord") + ) + eq_(cs(stmt2).all(), [x1, x2]) + eq_(cs(stmt2.params({"x": x1}).params({"y": x2})).all(), [x1, x2]) + + def test_text(self, connection): + a = self.tables.a + + cs = connection.scalars + + for _ in range(3): + x0 = random.randint(1, 10) + stmt = text("select data from a where data = :x").params(x=x0) + eq_(cs(stmt).all(), [x0]) + + x1 = random.randint(1, 10) + eq_(cs(stmt.params({"x": x1})).all(), [x1]) + + x2 = random.randint(1, 10) + stmt2 = stmt.columns(a.c.data).params(x=x2) + eq_(cs(stmt2).all(), [x2]) + eq_(cs(stmt2, {"x": 1}).all(), [1]) + eq_(cs(stmt2.params(x=1)).all(), [1])