]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improve params implementation
authorFederico Caselli <cfederico87@gmail.com>
Thu, 16 Oct 2025 18:59:30 +0000 (20:59 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Nov 2025 21:50:54 +0000 (16:50 -0500)
Added new implementation for the :meth:`.Select.params` method and that of
similar statements, via a new statement-only
:meth:`.ExecutableStatement.params` method which works more efficiently and
correctly than the previous implementations available from
:class:`.ClauseElement`, by assocating the given parameter dictionary with
the statement overall rather than cloning the statement and rewriting its
bound parameters.  The :meth:`_sql.ClauseElement.params` and
:meth:`_sql.ClauseElement.unique_params` methods, when called on an object
that does not implement :class:`.ExecutableStatement`, will continue to
work the old way of cloning the object, and will emit a deprecation
warning.    This issue both resolves the architectural / performance
concerns of :ticket:`7066` and also provides correct ORM compatibility for
functions like :func:`_orm.aliased`, reported by :ticket:`12915`.

Fixes: #7066
Change-Id: I6543c7d0f4da3232b3641fb172c24c446f02f52a

26 files changed:
doc/build/changelog/migration_21.rst
doc/build/changelog/unreleased_21/7066.rst [new file with mode: 0644]
doc/build/core/selectable.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/_util_cy.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/sql/lambdas.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
test/orm/test_core_compilation.py
test/sql/test_compare.py
test/sql/test_external_traversal.py
test/sql/test_functions.py
test/sql/test_selectable.py
test/sql/test_statement_params.py [new file with mode: 0644]

index 53d5c90ba12fa74ccae48ec36a19ed8a2193cf3c..8dd006798f425ceef6432a8cc38402a734a60606 100644 (file)
@@ -497,7 +497,6 @@ E.g.::
 New Features and Improvements - Core
 =====================================
 
-
 .. _change_10635:
 
 ``Row`` now represents individual column types directly without ``Tuple``
@@ -618,6 +617,51 @@ not the database portion::
 
 :ticket:`11234`
 
+.. _change_7066:
+
+Improved ``params()`` implementation for executable statements
+--------------------------------------------------------------
+
+The :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params`
+methods have been deprecated in favor of a new implementation on executable
+statements that provides improved performance and better integration with
+ORM-enabled statements.
+
+Executable statement objects like :class:`_sql.Select`, :class:`_sql.CompoundSelect`,
+and :class:`_sql.TextClause` now provide an improved :meth:`_sql.ExecutableStatement.params`
+method that avoids a full cloned traversal of the statement tree. Instead, parameters
+are stored directly on the statement object and efficiently merged during compilation
+and/or cache key traversal.
+
+The new implementation provides several benefits:
+
+* **Better performance** - Parameters are stored in a simple dictionary rather than
+  requiring a full statement tree traversal with cloning
+* **Proper caching integration** - Parameters are correctly integrated into SQLAlchemy's
+  cache key system via ``_generate_cache_key()``
+* **ORM statement compatibility** - Works correctly with ORM-enabled statements, including
+  ORM entities used with :func:`_orm.aliased`, subqueries, CTEs, etc.
+
+Use of :meth:`_sql.ExecutableStatement.params` is unchanged, provided the given
+object is a statement object such as :func:`_sql.select`::
+
+    stmt = select(table).where(table.c.data == bindparam("x"))
+
+    # Execute with parameter value
+    result = connection.execute(stmt.params(x=5))
+
+    # Can be chained and used in subqueries
+    stmt2 = stmt.params(x=6).subquery().select()
+    result = connection.execute(stmt2.params(x=7))  # Uses x=7
+
+The deprecated :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params`
+methods on non-executable elements like :class:`_sql.ColumnElement` and general
+:class:`_sql.ClauseElement` instances will continue to work during the deprecation
+period but will emit deprecation warnings.
+
+:ticket:`7066`
+
+
 .. _change_4950:
 
 CREATE TABLE AS SELECT Support
diff --git a/doc/build/changelog/unreleased_21/7066.rst b/doc/build/changelog/unreleased_21/7066.rst
new file mode 100644 (file)
index 0000000..27281b6
--- /dev/null
@@ -0,0 +1,21 @@
+.. change::
+    :tags: change, sql
+    :tickets: 7066, 12915
+
+    Added new implementation for the :meth:`.Select.params` method and that of
+    similar statements, via a new statement-only
+    :meth:`.ExecutableStatement.params` method which works more efficiently and
+    correctly than the previous implementations available from
+    :class:`.ClauseElement`, by assocating the given parameter dictionary with
+    the statement overall rather than cloning the statement and rewriting its
+    bound parameters.  The :meth:`_sql.ClauseElement.params` and
+    :meth:`_sql.ClauseElement.unique_params` methods, when called on an object
+    that does not implement :class:`.ExecutableStatement`, will continue to
+    work the old way of cloning the object, and will emit a deprecation
+    warning.    This issue both resolves the architectural / performance
+    concerns of :ticket:`7066` and also provides correct ORM compatibility for
+    functions like :func:`_orm.aliased`, reported by :ticket:`12915`.
+
+    .. seealso::
+
+        :ref:`change_7066`
index d7c2b56c8b2e7c36c807e8e7d9723fc0efe1df89..38d47aa657387026626aa99eea0c4305d1b8e996 100644 (file)
@@ -89,6 +89,9 @@ The classes here are generated using the constructors listed at
 .. autoclass:: Executable
    :members:
 
+.. autoclass:: ExecutableStatement
+   :members:
+
 .. autoclass:: Exists
    :members:
 
index 0d8bf6f08b2b9d3446904b4c740f26563b88902e..57cb0fc2a95e06ed30dd84e76d6fd6804fca85ff 100644 (file)
@@ -1641,13 +1641,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             "compiled_cache", self.engine._compiled_cache
         )
 
-        compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
-            dialect=dialect,
-            compiled_cache=compiled_cache,
-            column_keys=keys,
-            for_executemany=for_executemany,
-            schema_translate_map=schema_translate_map,
-            linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+        compiled_sql, extracted_params, param_dict, cache_hit = (
+            elem._compile_w_cache(
+                dialect=dialect,
+                compiled_cache=compiled_cache,
+                column_keys=keys,
+                for_executemany=for_executemany,
+                schema_translate_map=schema_translate_map,
+                linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+            )
         )
         ret = self._execute_context(
             dialect,
@@ -1660,6 +1662,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
             elem,
             extracted_params,
             cache_hit=cache_hit,
+            param_dict=param_dict,
         )
         if has_events:
             self.dispatch.after_execute(
index 367455d5e2fdbde57d48df6d4c76fab9ca61601b..1ffec4092d5be2c3abe9cf1cda807d6494892cc1 100644 (file)
@@ -1329,6 +1329,7 @@ class DefaultExecutionContext(ExecutionContext):
         invoked_statement: Executable,
         extracted_parameters: Optional[Sequence[BindParameter[Any]]],
         cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
+        param_dict: _CoreSingleExecuteParams | None = None,
     ) -> ExecutionContext:
         """Initialize execution context for a Compiled construct."""
 
@@ -1417,6 +1418,7 @@ class DefaultExecutionContext(ExecutionContext):
                 compiled.construct_params(
                     extracted_parameters=extracted_parameters,
                     escape_names=False,
+                    _collected_params=param_dict,
                 )
             ]
         else:
@@ -1426,6 +1428,7 @@ class DefaultExecutionContext(ExecutionContext):
                     escape_names=False,
                     _group_number=grp,
                     extracted_parameters=extracted_parameters,
+                    _collected_params=param_dict,
                 )
                 for grp, m in enumerate(parameters)
             ]
index 2dd6e950a7100192d49859e73601044f600cd39f..8f26eb2c5d4160371b536a89bbdc121a94afff1e 100644 (file)
@@ -54,6 +54,7 @@ from ..sql.base import _select_iterables
 from ..sql.base import CacheableOptions
 from ..sql.base import CompileState
 from ..sql.base import Executable
+from ..sql.base import ExecutableStatement
 from ..sql.base import Generative
 from ..sql.base import Options
 from ..sql.dml import UpdateBase
@@ -981,7 +982,7 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[Unpack[_Ts]]):
     _traverse_internals = [
         ("_raw_columns", InternalTraversal.dp_clauseelement_list),
         ("element", InternalTraversal.dp_clauseelement),
-    ] + Executable._executable_traverse_internals
+    ] + ExecutableStatement._executable_traverse_internals
 
     _cache_key_traversal = _traverse_internals + [
         ("_compile_options", InternalTraversal.dp_has_cache_key)
index 1231d166cde8be2fba6a7d47b5d95a46d3aa7b64..c49c0f6c3ad7ddd89807a7c43bc3143a990ef2e6 100644 (file)
@@ -11,6 +11,7 @@ from ._typing import ColumnExpressionArgument as ColumnExpressionArgument
 from ._typing import NotNullable as NotNullable
 from ._typing import Nullable as Nullable
 from .base import Executable as Executable
+from .base import ExecutableStatement as ExecutableStatement
 from .base import SyntaxExtension as SyntaxExtension
 from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
 from .compiler import FROM_LINTING as FROM_LINTING
index 8d4ef542b97c04647eb7d558ae387583040e3b39..ddb8680196df2325b5db17ae3dbe8e60c4fd81b6 100644 (file)
@@ -15,6 +15,7 @@ from typing import Union
 
 if TYPE_CHECKING:
     from .cache_key import CacheConst
+    from ..engine.interfaces import _CoreSingleExecuteParams
 
 # START GENERATED CYTHON IMPORT
 # This section is automatically generated by the script tools/cython_imports.py
@@ -67,13 +68,12 @@ class prefix_anon_map(Dict[str, str]):
         return value
 
 
+_AM_KEY = Union[int, str, "CacheConst"]
+_AM_VALUE = Union[int, Literal[True], "_CoreSingleExecuteParams"]
+
+
 @cython.cclass
-class anon_map(
-    Dict[
-        Union[int, str, "Literal[CacheConst.NO_CACHE]"],
-        Union[int, Literal[True]],
-    ]
-):
+class anon_map(Dict[_AM_KEY, _AM_VALUE]):
     """A map that creates new keys for missing key access.
 
     Produces an incrementing sequence given a series of unique keys.
@@ -96,9 +96,7 @@ class anon_map(
 
     @cython.cfunc  # type:ignore[misc]
     @cython.inline  # type:ignore[misc]
-    def _add_missing(
-        self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
-    ) -> int:
+    def _add_missing(self: anon_map, key: _AM_KEY, /) -> int:
         val: int = self._index
         self._index += 1
         self_dict: dict = self  # type: ignore[type-arg]
@@ -116,11 +114,7 @@ class anon_map(
 
     if cython.compiled:
 
-        def __getitem__(
-            self: anon_map,
-            key: Union[int, str, "Literal[CacheConst.NO_CACHE]"],
-            /,
-        ) -> Union[int, Literal[True]]:
+        def __getitem__(self: anon_map, key: _AM_KEY, /) -> _AM_VALUE:
             self_dict: dict = self  # type: ignore[type-arg]
 
             if key in self_dict:
@@ -128,7 +122,5 @@ class anon_map(
             else:
                 return self._add_missing(key)  # type:ignore[no-any-return]
 
-    def __missing__(
-        self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
-    ) -> int:
+    def __missing__(self: anon_map, key: _AM_KEY, /) -> int:
         return self._add_missing(key)  # type:ignore[no-any-return]
index a07b1204b91aa43862953ea1289dd7b12ebf0845..67eb44fc8dea8977f3135d60847e2838c9cbc0bb 100644 (file)
@@ -56,6 +56,7 @@ from .visitors import InternalTraversal
 from .. import event
 from .. import exc
 from .. import util
+from ..util import EMPTY_DICT
 from ..util import HasMemoized as HasMemoized
 from ..util import hybridmethod
 from ..util.typing import Self
@@ -92,6 +93,7 @@ if TYPE_CHECKING:
     from ..engine import Connection
     from ..engine import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
+    from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import _ExecuteOptions
     from ..engine.interfaces import _ImmutableExecuteOptions
     from ..engine.interfaces import CacheStats
@@ -1010,7 +1012,7 @@ class CacheableOptions(Options, HasCacheKey):
 
     @hybridmethod
     def _generate_cache_key(self) -> Optional[CacheKey]:
-        return HasCacheKey._generate_cache_key_for_object(self)
+        return HasCacheKey._generate_cache_key(self)
 
 
 class ExecutableOption(HasCopyInternals):
@@ -1287,8 +1289,11 @@ class Executable(roles.StatementRole):
             for_executemany: bool = False,
             schema_translate_map: Optional[SchemaTranslateMapType] = None,
             **kw: Any,
-        ) -> Tuple[
-            Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+        ) -> tuple[
+            Compiled,
+            Sequence[BindParameter[Any]] | None,
+            _CoreSingleExecuteParams | None,
+            CacheStats,
         ]: ...
 
         def _execute_on_connection(
@@ -1540,6 +1545,50 @@ class Executable(roles.StatementRole):
         return self._execution_options
 
 
+class ExecutableStatement(Executable):
+    """Executable subclass that implements a lightweight version of ``params``
+    that avoids a full cloned traverse.
+
+    .. versionadded:: 2.1
+
+    """
+
+    _params: util.immutabledict[str, Any] = EMPTY_DICT
+
+    _executable_traverse_internals = (
+        Executable._executable_traverse_internals
+        + [("_params", InternalTraversal.dp_params)]
+    )
+
+    @_generative
+    def params(
+        self,
+        __optionaldict: _CoreSingleExecuteParams | None = None,
+        /,
+        **kwargs: Any,
+    ) -> Self:
+        """Return a copy with the provided bindparam values.
+
+        Returns a copy of this Executable with bindparam values set
+        to the given dictionary::
+
+          >>> clause = column("x") + bindparam("foo")
+          >>> print(clause.compile().params)
+          {'foo': None}
+          >>> print(clause.params({"foo": 7}).compile().params)
+          {'foo': 7}
+
+        """
+        if __optionaldict:
+            kwargs.update(__optionaldict)
+        self._params = (
+            util.immutabledict(kwargs)
+            if not self._params
+            else self._params | kwargs
+        )
+        return self
+
+
 class SchemaEventTarget(event.EventTarget):
     """Base class for elements that are the targets of :class:`.DDLEvents`
     events.
index f44ca268863e5df2cb609fe75fd7bcc90240ecd6..e35c201f21efdc39c8b056a95d63464d6ba00bcc 100644 (file)
@@ -13,6 +13,7 @@ import typing
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Final
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -50,9 +51,10 @@ class _CacheKeyTraversalDispatchType(Protocol):
 
 class CacheConst(enum.Enum):
     NO_CACHE = 0
+    PARAMS = 1
 
 
-NO_CACHE = CacheConst.NO_CACHE
+NO_CACHE: Final = CacheConst.NO_CACHE
 
 
 _CacheKeyTraversalType = Union[
@@ -384,21 +386,11 @@ class HasCacheKey:
             return None
         else:
             assert key is not None
-            return CacheKey(key, bindparams)
-
-    @classmethod
-    def _generate_cache_key_for_object(
-        cls, obj: HasCacheKey
-    ) -> Optional[CacheKey]:
-        bindparams: List[BindParameter[Any]] = []
-
-        _anon_map = anon_map()
-        key = obj._gen_cache_key(_anon_map, bindparams)
-        if NO_CACHE in _anon_map:
-            return None
-        else:
-            assert key is not None
-            return CacheKey(key, bindparams)
+            return CacheKey(
+                key,
+                bindparams,
+                _anon_map.get(CacheConst.PARAMS),  # type: ignore[arg-type]
+            )
 
 
 class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
@@ -432,6 +424,7 @@ class CacheKey(NamedTuple):
 
     key: Tuple[Any, ...]
     bindparams: Sequence[BindParameter[Any]]
+    params: _CoreSingleExecuteParams | None
 
     # can't set __hash__ attribute because it interferes
     # with namedtuple
@@ -485,8 +478,8 @@ class CacheKey(NamedTuple):
 
     @classmethod
     def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
-        ck1 = CacheKey(left, [])
-        ck2 = CacheKey(right, [])
+        ck1 = CacheKey(left, [], None)
+        ck2 = CacheKey(right, [], None)
         return ck1._diff(ck2)
 
     def _whats_different(self, other: CacheKey) -> Iterator[str]:
@@ -1053,5 +1046,21 @@ class _CacheKeyTraversal(HasTraversalDispatch):
         anon_map[NO_CACHE] = True
         return ()
 
+    def visit_params(
+        self,
+        attrname: str,
+        obj: Any,
+        parent: Any,
+        anon_map: anon_map,
+        bindparams: List[BindParameter[Any]],
+    ) -> Tuple[Any, ...]:
+        if obj:
+            if CacheConst.PARAMS in anon_map:
+                to_set = anon_map[CacheConst.PARAMS] | obj
+            else:
+                to_set = obj
+            anon_map[CacheConst.PARAMS] = to_set
+        return ()
+
 
 _cache_key_traversal_visitor = _CacheKeyTraversal()
index c666fab02541a4f1119f37a35c09f16454edad07..ee13cf8da08d35365a41863279c2f7336c5440c5 100644 (file)
@@ -40,6 +40,7 @@ from typing import Callable
 from typing import cast
 from typing import ClassVar
 from typing import Dict
+from typing import Final
 from typing import FrozenSet
 from typing import Iterable
 from typing import Iterator
@@ -93,6 +94,7 @@ if typing.TYPE_CHECKING:
     from .base import _AmbiguousTableNameMap
     from .base import CompileState
     from .base import Executable
+    from .base import ExecutableStatement
     from .cache_key import CacheKey
     from .ddl import CreateTableAs
     from .ddl import ExecutableDDLElement
@@ -1390,6 +1392,8 @@ class SQLCompiler(Compiled):
     _positional_pattern = re.compile(
         f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
     )
+    _collect_params: Final[bool]
+    _collected_params: util.immutabledict[str, Any]
 
     @classmethod
     def _init_compiler_cls(cls):
@@ -1489,6 +1493,11 @@ class SQLCompiler(Compiled):
         # dialect.label_length or dialect.max_identifier_length
         self.truncated_names: Dict[Tuple[str, str], str] = {}
         self._truncated_counters: Dict[str, int] = {}
+        if not cache_key:
+            self._collect_params = True
+            self._collected_params = util.EMPTY_DICT
+        else:
+            self._collect_params = False  # type: ignore[misc]
 
         Compiled.__init__(self, dialect, statement, **kwargs)
 
@@ -1627,6 +1636,13 @@ class SQLCompiler(Compiled):
     def _global_attributes(self) -> Dict[Any, Any]:
         return {}
 
+    def _add_to_params(self, item: ExecutableStatement) -> None:
+        # assumes that this is called before traversing the statement
+        # so the call happens outer to inner, meaning that existing params
+        # take precedence
+        if item._params:
+            self._collected_params = item._params | self._collected_params
+
     @util.memoized_instancemethod
     def _init_cte_state(self) -> MutableMapping[CTE, str]:
         """Initialize collections related to CTEs only if
@@ -1874,8 +1890,19 @@ class SQLCompiler(Compiled):
         _group_number: Optional[int] = None,
         _check: bool = True,
         _no_postcompile: bool = False,
+        _collected_params: _CoreSingleExecuteParams | None = None,
     ) -> _MutableCoreSingleExecuteParams:
         """return a dictionary of bind parameter keys and values"""
+        if _collected_params is not None:
+            assert not self._collect_params
+        elif self._collect_params:
+            _collected_params = self._collected_params
+
+        if _collected_params:
+            if not params:
+                params = _collected_params
+            else:
+                params = {**_collected_params, **params}
 
         if self._render_postcompile and not _no_postcompile:
             assert self._post_compile_expanded_state is not None
@@ -2704,6 +2731,9 @@ class SQLCompiler(Compiled):
         return text
 
     def visit_textclause(self, textclause, add_to_result_map=None, **kw):
+        if self._collect_params:
+            self._add_to_params(textclause)
+
         def do_bindparam(m):
             name = m.group(1)
             if name in textclause._bindparams:
@@ -2731,6 +2761,8 @@ class SQLCompiler(Compiled):
     def visit_textual_select(
         self, taf, compound_index=None, asfrom=False, **kw
     ):
+        if self._collect_params:
+            self._add_to_params(taf)
         toplevel = not self.stack
         entry = self._default_stack_entry if toplevel else self.stack[-1]
 
@@ -3026,6 +3058,8 @@ class SQLCompiler(Compiled):
         add_to_result_map: Optional[_ResultMapAppender] = None,
         **kwargs: Any,
     ) -> str:
+        if self._collect_params:
+            self._add_to_params(func)
         if add_to_result_map is not None:
             add_to_result_map(func.name, func.name, (func.name,), func.type)
 
@@ -3081,6 +3115,8 @@ class SQLCompiler(Compiled):
     def visit_compound_select(
         self, cs, asfrom=False, compound_index=None, **kwargs
     ):
+        if self._collect_params:
+            self._add_to_params(cs)
         toplevel = not self.stack
 
         compile_state = cs._compile_state_factory(cs, self, **kwargs)
@@ -4870,6 +4906,8 @@ class SQLCompiler(Compiled):
             "the translate_select_structure hook for structural "
             "translations of SELECT objects"
         )
+        if self._collect_params:
+            self._add_to_params(select_stmt)
 
         # initial setup of SELECT.  the compile_state_factory may now
         # be creating a totally different SELECT from the one that was
index 58a8c3c8e8c4b6499ce964bfef43f5a6ba31b3f6..be8781bb76eb0b07055ee644e7f9c9976f92b2b1 100644 (file)
@@ -55,6 +55,7 @@ if typing.TYPE_CHECKING:
     from .schema import Sequence as Sequence  # noqa: F401
     from .schema import Table
     from ..engine.base import Connection
+    from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
     from ..engine.interfaces import Dialect
@@ -89,8 +90,11 @@ class BaseDDLElement(ClauseElement):
         for_executemany: bool = False,
         schema_translate_map: Optional[SchemaTranslateMapType] = None,
         **kw: Any,
-    ) -> Tuple[
-        Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats
+    ) -> tuple[
+        Compiled,
+        typing_Sequence[BindParameter[Any]] | None,
+        _CoreSingleExecuteParams | None,
+        CacheStats,
     ]:
         raise NotImplementedError()
 
index 54e95d20a6a3dc8cb1b8270bb5d7e1ae0da5fef8..e85e98a6f8f48d27c46a9e0e4e42cd5137aad4b4 100644 (file)
@@ -48,6 +48,7 @@ from .base import ColumnSet
 from .base import CompileState
 from .base import DialectKWArgs
 from .base import Executable
+from .base import ExecutableStatement
 from .base import Generative
 from .base import HasCompileState
 from .base import HasSyntaxExtensions
@@ -1249,7 +1250,7 @@ class Insert(ValuesBase, HasSyntaxExtensions[Literal["post_values"]]):
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
         + HasCTE._has_ctes_traverse_internals
     )
 
@@ -1614,7 +1615,7 @@ class Update(
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
         + HasCTE._has_ctes_traverse_internals
     )
 
@@ -1815,7 +1816,7 @@ class Delete(
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
         + HasCTE._has_ctes_traverse_internals
     )
 
index ddbfd00c69c6f6047b1677ee37f8a562a60a8479..7a640ccacc8a0b1188fda7382a523051d8b6a1c5 100644 (file)
@@ -58,6 +58,7 @@ from .base import _expand_cloned
 from .base import _generative
 from .base import _NoArg
 from .base import Executable
+from .base import ExecutableStatement
 from .base import Generative
 from .base import HasMemoized
 from .base import Immutable
@@ -77,6 +78,7 @@ from .visitors import Visitable
 from .. import exc
 from .. import inspection
 from .. import util
+from ..util import deprecated
 from ..util import HasMemoized_ro_memoized_attribute
 from ..util import TypingOnly
 from ..util.typing import Self
@@ -117,6 +119,7 @@ if typing.TYPE_CHECKING:
     from ..engine import Connection
     from ..engine import Dialect
     from ..engine.interfaces import _CoreMultiExecuteParams
+    from ..engine.interfaces import _CoreSingleExecuteParams
     from ..engine.interfaces import CacheStats
     from ..engine.interfaces import CompiledCacheType
     from ..engine.interfaces import CoreExecuteOptionsParameter
@@ -610,6 +613,12 @@ class ClauseElement(
         """
         return self._replace_params(False, __optionaldict, kwargs)
 
+    @deprecated(
+        "2.1",
+        "The params() and unique_params() methods on non-statement "
+        "ClauseElement objects are deprecated; params() is now limited to "
+        "statement level objects such as select(), insert(), union(), etc. ",
+    )
     def _replace_params(
         self,
         unique: bool,
@@ -691,8 +700,11 @@ class ClauseElement(
         for_executemany: bool = False,
         schema_translate_map: Optional[SchemaTranslateMapType] = None,
         **kw: Any,
-    ) -> typing_Tuple[
-        Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+    ) -> tuple[
+        Compiled,
+        Sequence[BindParameter[Any]] | None,
+        _CoreSingleExecuteParams | None,
+        CacheStats,
     ]:
         elem_cache_key: Optional[CacheKey]
 
@@ -706,7 +718,7 @@ class ClauseElement(
             if TYPE_CHECKING:
                 assert compiled_cache is not None
 
-            cache_key, extracted_params = elem_cache_key
+            cache_key, extracted_params, param_dict = elem_cache_key
             key = (
                 dialect,
                 cache_key,
@@ -726,19 +738,26 @@ class ClauseElement(
                     schema_translate_map=schema_translate_map,
                     **kw,
                 )
+                # ensure that params of the current statement are not
+                # left in the cache
+                assert not compiled_sql._collect_params  # type: ignore[attr-defined] # noqa: E501
                 compiled_cache[key] = compiled_sql
             else:
                 cache_hit = dialect.CACHE_HIT
         else:
+            param_dict = None
             extracted_params = None
             compiled_sql = self._compiler(
                 dialect,
-                cache_key=elem_cache_key,
+                cache_key=None,
                 column_keys=column_keys,
                 for_executemany=for_executemany,
                 schema_translate_map=schema_translate_map,
                 **kw,
             )
+            # here instead the params need to be extracted, since we don't
+            # have them otherwise
+            assert compiled_sql._collect_params  # type: ignore[attr-defined] # noqa: E501
 
             if not dialect._supports_statement_cache:
                 cache_hit = dialect.NO_DIALECT_SUPPORT
@@ -747,7 +766,7 @@ class ClauseElement(
             else:
                 cache_hit = dialect.NO_CACHE_KEY
 
-        return compiled_sql, extracted_params, cache_hit
+        return compiled_sql, extracted_params, param_dict, cache_hit
 
     def __invert__(self):
         # undocumented element currently used by the ORM for
@@ -2314,7 +2333,7 @@ class TextClause(
     roles.SelectStatementRole,
     roles.InElementRole,
     Generative,
-    Executable,
+    ExecutableStatement,
     DQLDMLClauseElement,
     roles.BinaryElementRole[Any],
     inspection.Inspectable["TextClause"],
@@ -2343,7 +2362,7 @@ class TextClause(
     _traverse_internals: _TraverseInternalsType = [
         ("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
         ("text", InternalTraversal.dp_string),
-    ] + Executable._executable_traverse_internals
+    ] + ExecutableStatement._executable_traverse_internals
 
     _is_text_clause = True
 
index d85142ed90d8c7e96f8ef0dca2d2671b1289deec..60029dbf4fc791c586b0b0e5b70c5ddb4b6613e3 100644 (file)
@@ -66,6 +66,7 @@ from .base import _from_objects as _from_objects
 from .base import _select_iterables as _select_iterables
 from .base import ColumnCollection as ColumnCollection
 from .base import Executable as Executable
+from .base import ExecutableStatement as ExecutableStatement
 from .cache_key import CacheKey as CacheKey
 from .dml import Delete as Delete
 from .dml import Insert as Insert
index f4437bb980851bc11e3939e345c60d1d65238b04..8671986a6b2ad8f10813be7f72415cf27313c02f 100644 (file)
@@ -36,7 +36,7 @@ from . import util as sqlutil
 from ._typing import is_table_value_type
 from .base import _entity_namespace
 from .base import ColumnCollection
-from .base import Executable
+from .base import ExecutableStatement
 from .base import Generative
 from .base import HasMemoized
 from .elements import _type_from_args
@@ -114,7 +114,9 @@ def register_function(
     reg[identifier] = fn
 
 
-class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
+class FunctionElement(
+    ColumnElement[_T], ExecutableStatement, FromClause, Generative
+):
     """Base for SQL function-oriented constructs.
 
     This is a `generic type <https://peps.python.org/pep-0484/#generics>`_,
@@ -140,7 +142,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         ("clause_expr", InternalTraversal.dp_clauseelement),
         ("_with_ordinality", InternalTraversal.dp_boolean),
         ("_table_value_type", InternalTraversal.dp_has_cache_key),
-    ] + Executable._executable_traverse_internals
+    ] + ExecutableStatement._executable_traverse_internals
 
     packagenames: Tuple[str, ...] = ()
 
index 731f7e5932069d278a6b78b8ec86d2f3319be4a4..02fcd34131ea664608f77366f10745a553c4193b 100644 (file)
@@ -37,6 +37,7 @@ from . import schema
 from . import visitors
 from .base import _clone
 from .base import Executable
+from .base import ExecutableStatement
 from .base import Options
 from .cache_key import CacheConst
 from .operators import ColumnOperators
@@ -499,7 +500,7 @@ class DeferredLambdaElement(LambdaElement):
 
 
 class StatementLambdaElement(
-    roles.AllowsLambdaRole, LambdaElement, Executable
+    roles.AllowsLambdaRole, ExecutableStatement, LambdaElement
 ):
     """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
 
index 6e62d30bc49b5babcd8b195abd78bc912c3ae071..383f81b432155518b96048db930366026f951eac 100644 (file)
@@ -76,6 +76,7 @@ from .base import CompileState
 from .base import DedupeColumnCollection
 from .base import DialectKWArgs
 from .base import Executable
+from .base import ExecutableStatement
 from .base import Generative
 from .base import HasCompileState
 from .base import HasMemoized
@@ -290,7 +291,7 @@ class ReturnsRows(roles.ReturnsRowsRole, DQLDMLClauseElement):
         raise NotImplementedError()
 
 
-class ExecutableReturnsRows(Executable, ReturnsRows):
+class ExecutableReturnsRows(ExecutableStatement, ReturnsRows):
     """base for executable statements that return rows."""
 
 
@@ -4593,7 +4594,7 @@ class CompoundSelect(
         + SupportsCloneAnnotations._clone_annotations_traverse_internals
         + HasCTE._has_ctes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
     )
 
     selects: List[SelectBase]
@@ -5484,7 +5485,7 @@ class Select(
         + HasSuffixes._has_suffixes_traverse_internals
         + HasHints._has_hints_traverse_internals
         + SupportsCloneAnnotations._clone_annotations_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
     )
 
@@ -7251,7 +7252,7 @@ class TextualSelect(SelectBase, ExecutableReturnsRows, Generative):
         ]
         + SupportsCloneAnnotations._clone_annotations_traverse_internals
         + HasCTE._has_ctes_traverse_internals
-        + Executable._executable_traverse_internals
+        + ExecutableStatement._executable_traverse_internals
     )
 
     _is_textual = True
index 38f8e3e162355017efaa43705c3baf0b1b2efdc4..85dfda281845963d77544c69d937eaceafec606d 100644 (file)
@@ -953,6 +953,11 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
                 ):
                     return COMPARE_FAILED
 
+    def visit_params(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+        return left == right
+
     def compare_expression_clauselist(self, left, right, **kw):
         if left.operator is right.operator:
             if operators.is_associative(left.operator):
index 1cd7097f2d2a306337c034c6a006a4c78d12d6ac..396b06b21d9a0c27e2798e81e3d7e7f8b85ed34e 100644 (file)
@@ -410,6 +410,9 @@ class InternalTraversal(Enum):
     """Visit a list of inspectable objects which upon inspection are
     HasCacheKey objects."""
 
+    dp_params = "PM"
+    """Visit the _params collection of ExecutableStatement"""
+
 
 _TraverseInternalsType = List[Tuple[str, InternalTraversal]]
 """a structure that defines how a HasTraverseInternals should be
index 3b98d60be780c5e969026940b7e912aebd511cf6..b2fb34ad5c90c1a476c3fe527136c2b139badb78 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.orm import with_loader_criteria
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.sql import and_
 from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql import visitors
 from sqlalchemy.sql.selectable import Join as core_join
 from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
 from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
@@ -367,6 +368,60 @@ class SelectableTest(QueryTest, AssertsCompiledSQL):
             checkparams={"param_1": 6, "param_2": 5},
         )
 
+    @testing.variation("use_get_params", [True, False])
+    def test_annotated_cte_params_traverse(self, use_get_params):
+        """test #12915
+
+        Tests the original issue in #12915 which was a specific issue
+        involving cloned_traverse with Annotated subclasses, where traversal
+        would not properly cover a CTE's self-referential structure.
+
+        This case still does not work in the general ORM case, so the
+        implementation of .params() was changed to not rely upon
+        cloned_traversal.
+
+        """
+        User = self.classes.User
+        ids_param = bindparam("ids")
+        cte = select(User).where(User.id == ids_param).cte("cte")
+        ca = cte._annotate({"foo": "bar"})
+        stmt = select(ca)
+        if use_get_params:
+            stmt = stmt.params(ids=17)
+        else:
+            # test without using params(), in case the implementation
+            # for params() changes we still want to test cloned_traverse
+            def visit_bindparam(bind):
+                if bind.key == "ids":
+                    bind.value = 17
+                    bind.required = False
+
+            stmt = visitors.cloned_traverse(
+                stmt,
+                {"maintain_key": True, "detect_subquery_cols": True},
+                {"bindparam": visit_bindparam},
+            )
+        self.assert_compile(
+            stmt,
+            "WITH cte AS (SELECT users.id AS id, users.name AS name "
+            "FROM users WHERE users.id = :ids) "
+            "SELECT cte.id, cte.name FROM cte",
+            checkparams={"ids": 17},
+        )
+
+    def test_orm_cte_with_params(self, connection):
+        """test for #12915's new implementation"""
+        User = self.classes.User
+        ids_param = bindparam("ids")
+        cte = select(User).where(User.id == ids_param).cte("cte")
+        stmt = select(aliased(User, cte.alias("a1"), adapt_on_names=True))
+
+        res = connection.execute(stmt, {"ids": 7}).all()
+        eq_(res, [(7, "jack")])
+        with Session(connection) as s:
+            res = s.scalars(stmt, {"ids": 7}).all()
+        eq_(res, [User(id=7, name="jack")])
+
 
 class PropagateAttrsTest(QueryTest):
     __backend__ = True
index ff7fd782ab211735ad70743a4c4ae8b3ad1f7c15..fbc6db81c5eb3622d23b528478dc926715a04592 100644 (file)
@@ -2353,29 +2353,37 @@ class TestCacheKeyUtil(fixtures.TestBase):
         eq_(
             re.compile(r"[\n\s]+", re.M).sub(
                 " ",
-                str(CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[])),
+                str(
+                    CacheKey(
+                        key=((1, (2, 7, 4), 5),), bindparams=[], params={}
+                    )
+                ),
             ),
             "CacheKey(key=( ( 1, ( 2, 7, 4, ), 5, ), ),)",
         )
 
     def test_nested_tuple_difference(self):
         """Test difference detection in nested tuples"""
-        k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[])
-        k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[])
+        k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[], params={})
+        k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[], params={})
 
         eq_(list(k1._whats_different(k2)), ["key[0][1][1]:  3 != 7"])
 
     def test_deeply_nested_tuple_difference(self):
         """Test difference detection in deeply nested tuples"""
-        k1 = CacheKey(key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[])
-        k2 = CacheKey(key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[])
+        k1 = CacheKey(
+            key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[], params={}
+        )
+        k2 = CacheKey(
+            key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[], params={}
+        )
 
         eq_(list(k1._whats_different(k2)), ["key[0][1][1][1]:  4 != 9"])
 
     def test_multiple_differences_nested(self):
         """Test detection of multiple differences in nested structure"""
-        k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[])
-        k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[])
+        k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[], params={})
+        k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[], params={})
 
         eq_(
             list(k1._whats_different(k2)),
@@ -2384,21 +2392,29 @@ class TestCacheKeyUtil(fixtures.TestBase):
 
     def test_diff_method(self):
         """Test the _diff() method that returns a comma-separated string"""
-        k1 = CacheKey(key=((1, (2, 3)),), bindparams=[])
-        k2 = CacheKey(key=((1, (5, 7)),), bindparams=[])
+        k1 = CacheKey(key=((1, (2, 3)),), bindparams=[], params={})
+        k2 = CacheKey(key=((1, (5, 7)),), bindparams=[], params={})
 
         eq_(k1._diff(k2), "key[0][1][0]:  2 != 5, key[0][1][1]:  3 != 7")
 
     def test_with_string_differences(self):
         """Test detection of string differences"""
-        k1 = CacheKey(key=(("name", ("x", "value")),), bindparams=[])
-        k2 = CacheKey(key=(("name", ("y", "value")),), bindparams=[])
+        k1 = CacheKey(
+            key=(("name", ("x", "value")),), bindparams=[], params={}
+        )
+        k2 = CacheKey(
+            key=(("name", ("y", "value")),), bindparams=[], params={}
+        )
 
         eq_(list(k1._whats_different(k2)), ["key[0][1][0]:  x != y"])
 
     def test_with_mixed_types(self):
         """Test detection of differences with mixed types"""
-        k1 = CacheKey(key=(("id", 1, ("nested", 100)),), bindparams=[])
-        k2 = CacheKey(key=(("id", 1, ("nested", 200)),), bindparams=[])
+        k1 = CacheKey(
+            key=(("id", 1, ("nested", 100)),), bindparams=[], params={}
+        )
+        k2 = CacheKey(
+            key=(("id", 1, ("nested", 200)),), bindparams=[], params={}
+        )
 
         eq_(list(k1._whats_different(k2)), ["key[0][2][1]:  100 != 200"])
index d044d8b57f064d78fa910da7c8a3bdc2e96d53e0..c866df3c05bd7b02603bf3afa839fa72c515d3e2 100644 (file)
@@ -46,6 +46,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.testing.schema import eq_clause_element
 
 A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None
@@ -898,13 +899,17 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         params = {"xb": 42, "yb": 33}
         sel = select(Y).select_from(jj).params(params)
 
+        cc = sel._generate_cache_key()
+
         eq_(
             [
-                eq_clause_element(bindparam("yb", value=33)),
-                eq_clause_element(bindparam("xb", value=42)),
+                eq_clause_element(bindparam("yb", None, Integer)),
+                eq_clause_element(bindparam("xb", None, Integer)),
             ],
-            sel._generate_cache_key()[1],
+            cc[1],
         )
+        eq_(cc[2], {"xb": 42, "yb": 33})
+        eq_(sel.compile().params, params)
 
     def test_dont_traverse_immutables(self):
         meta = MetaData()
@@ -960,7 +965,12 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             stmt._generate_cache_key()[1],
         )
 
-        stmt = select(b).where(criteria.params({param_key: "some other data"}))
+        with expect_deprecated(
+            r"The params\(\) and unique_params\(\) methods on non-statement"
+        ):
+            stmt = select(b).where(
+                criteria.params({param_key: "some other data"})
+            )
         self.assert_compile(
             stmt,
             "SELECT b.id, b.data FROM b, (SELECT b.id AS id "
@@ -1004,7 +1014,11 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
             Ps,
             join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)),
             and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p),
-        ).params(params)
+        )
+        with expect_deprecated(
+            r"The params\(\) and unique_params\(\) methods on non-statement"
+        ):
+            jj = jj.params(params)
 
         eq_(
             [
@@ -1039,11 +1053,13 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
 
         pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
         s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
-        jj = (
-            join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p))
-            .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p))
-            .params(params)
+        jj = join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p)).join(
+            s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p)
         )
+        with expect_deprecated(
+            r"The params\(\) and unique_params\(\) methods on non-statement"
+        ):
+            jj = jj.params(params)
 
         eq_(
             [
index c64aec70646fca8867937631c58fc54ce494ca40..1aee4300ab7f9f17e8e9518d0ef0720c935d6a14 100644 (file)
@@ -59,6 +59,7 @@ from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_deprecated
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertions import expect_warnings
 from sqlalchemy.testing.engines import all_dialects
@@ -691,12 +692,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE users.id > anon_1.z",
         )
 
-        s = select(users).where(
-            users.c.id.between(
-                calculate.alias("c1").unique_params(x=17, y=45).c.z,
-                calculate.alias("c2").unique_params(x=5, y=12).c.z,
-            ),
-        )
+        with expect_deprecated(
+            r"The params\(\) and unique_params\(\) methods on non-statement"
+        ):
+            s = select(users).where(
+                users.c.id.between(
+                    calculate.alias("c1").unique_params(x=17, y=45).c.z,
+                    calculate.alias("c2").unique_params(x=5, y=12).c.z,
+                ),
+            )
 
         self.assert_compile(
             s,
index c4c28b4f46ffeb0cb161b9f1403bfd22604e0dcb..191ad6636984f7269e955b4c7e1f528ef84ba4d5 100644 (file)
@@ -3147,7 +3147,7 @@ class AnnotationsTest(fixtures.TestBase):
         """
         user = Table("user", MetaData(), Column("id", Integer))
 
-        ids_param = bindparam("ids")
+        ids_param = bindparam("ids", -1)
 
         cte = select(user).where(user.c.id == ids_param).cte("cte")
 
@@ -3157,6 +3157,8 @@ class AnnotationsTest(fixtures.TestBase):
 
         if use_get_params:
             stmt = stmt.params(ids=17)
+            exp = -1
+            eq_(stmt._generate_cache_key()[2], {"ids": 17})
         else:
             # test without using params(), as the implementation
             # for params() will be changing
@@ -3170,13 +3172,16 @@ class AnnotationsTest(fixtures.TestBase):
                 {"maintain_key": True, "detect_subquery_cols": True},
                 {"bindparam": visit_bindparam},
             )
+            exp = 17
+            eq_(stmt._generate_cache_key()[2], None)
 
         eq_(
             stmt.selected_columns.id.table.element._where_criteria[
                 0
             ].right.value,
-            17,
+            exp,
         )
+        eq_(stmt.compile().params, {"ids": 17})
 
     def test_basic_attrs(self):
         t = Table(
diff --git a/test/sql/test_statement_params.py b/test/sql/test_statement_params.py
new file mode 100644 (file)
index 0000000..d4f2bf4
--- /dev/null
@@ -0,0 +1,199 @@
+import random
+
+from sqlalchemy import testing
+from sqlalchemy.schema import Column
+from sqlalchemy.sql import bindparam
+from sqlalchemy.sql import column
+from sqlalchemy.sql import dml
+from sqlalchemy.sql import func
+from sqlalchemy.sql import select
+from sqlalchemy.sql import text
+from sqlalchemy.sql.base import ExecutableStatement
+from sqlalchemy.sql.elements import literal
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_deprecated
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_not
+from sqlalchemy.testing import ne_
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import Integer
+from sqlalchemy.types import Text
+from sqlalchemy.util.langhelpers import class_hierarchy
+
+
+class BasicTests(fixtures.TestBase):
+    def _all_subclasses(self, cls_):
+        return dict.fromkeys(
+            s
+            for s in class_hierarchy(cls_)
+            # class_hierarchy may return values that
+            # aren't subclasses of cls
+            if issubclass(s, cls_)
+        )
+
+    @staticmethod
+    def _relevant_impls():
+        return (
+            text("select 1 + 2"),
+            text("select 42 as q").columns(column("q", Integer)),
+            func.max(42),
+            select(1, 2).union(select(3, 4)),
+            select(1, 2),
+        )
+
+    def test_params_impl(self):
+        exclude = (dml.UpdateBase,)
+        visit_names = set()
+        for cls_ in self._all_subclasses(ExecutableStatement):
+            if not issubclass(cls_, exclude):
+                if "__visit_name__" in cls_.__dict__:
+                    visit_names.add(cls_.__visit_name__)
+                eq_(cls_.params, ExecutableStatement.params, cls_)
+            else:
+                ne_(cls_.params, ExecutableStatement.params, cls_)
+                for other in exclude:
+                    if issubclass(cls_, other):
+                        eq_(cls_.params, other.params, cls_)
+                        break
+                else:
+                    assert False
+
+        extra = {"orm_from_statement"}
+        eq_(
+            visit_names - extra,
+            {i.__visit_name__ for i in self._relevant_impls()},
+        )
+
+    @testing.combinations(*_relevant_impls())
+    def test_compile_params(self, impl):
+        new = impl.params(foo=5, bar=10)
+        is_not(new, impl)
+        eq_(impl.compile()._collected_params, {})
+        eq_(new.compile()._collected_params, {"foo": 5, "bar": 10})
+        eq_(new._generate_cache_key()[2], {"foo": 5, "bar": 10})
+
+
+class CacheTests(fixtures.TablesTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("a", metadata, Column("data", Integer))
+        Table("b", metadata, Column("data", Text))
+
+    @classmethod
+    def insert_data(cls, connection):
+        connection.execute(
+            cls.tables.a.insert(),
+            [{"data": i} for i in range(1, 11)],
+        )
+        connection.execute(
+            cls.tables.b.insert(),
+            [{"data": "row %d" % i} for i in range(1, 11)],
+        )
+
+    def test_plain_select(self, connection):
+        a = self.tables.a
+
+        cs = connection.scalars
+
+        for _ in range(3):
+            x1 = random.randint(1, 10)
+
+            eq_(cs(select(a).where(a.c.data == x1)).all(), [x1])
+            stmt = select(a).where(a.c.data == bindparam("x", x1))
+            eq_(cs(stmt).all(), [x1])
+
+            x1 = random.randint(1, 10)
+            eq_(cs(stmt.params({"x": x1})).all(), [x1])
+
+            x1 = random.randint(1, 10)
+            eq_(cs(stmt, {"x": x1}).all(), [x1])
+
+            x1 = random.randint(1, 10)
+            x2 = random.randint(1, 10)
+            eq_(cs(stmt.params({"x": x1}), {"x": x2}).all(), [x2])
+
+            stmt2 = stmt.params(x=6).subquery().select()
+            eq_(cs(stmt2).all(), [6])
+            eq_(cs(stmt2.params({"x": 2})).all(), [2])
+
+            with expect_deprecated(
+                r"The params\(\) and unique_params\(\) "
+                "methods on non-statement"
+            ):
+                # NOTE: can't mix and match the two params styles here
+                stmt3 = stmt.params(x=6).subquery().params(x=8).select()
+            eq_(cs(stmt3).all(), [6])
+            eq_(cs(stmt3.params({"x": 9})).all(), [9])
+
+    def test_union(self, connection):
+        a = self.tables.a
+
+        cs = connection.scalars
+        for _ in range(3):
+            x1 = random.randint(1, 10)
+            x2 = random.randint(1, 10)
+
+            eq_(
+                cs(
+                    select(a)
+                    .where(a.c.data == x1)
+                    .union_all(select(a).where(a.c.data == x2))
+                    .order_by(a.c.data)
+                ).all(),
+                sorted([x1, x2]),
+            )
+
+            x1 = random.randint(1, 10)
+            x2 = random.randint(1, 10)
+            stmt = (
+                select(a, literal(1).label("ord"))
+                .where(a.c.data == bindparam("x", x1))
+                .union_all(
+                    select(a, literal(2)).where(a.c.data == bindparam("y", x2))
+                )
+                .order_by("ord")
+            )
+            eq_(cs(stmt).all(), [x1, x2])
+
+            x1a = random.randint(1, 10)
+            eq_(cs(stmt.params({"x": x1a})).all(), [x1a, x2])
+
+            x2 = random.randint(1, 10)
+            eq_(cs(stmt, {"y": x2}).all(), [x1, x2])
+
+            x1 = random.randint(1, 10)
+            x2 = random.randint(1, 10)
+            eq_(cs(stmt.params({"x": x1}), {"y": x2}).all(), [x1, x2])
+
+            x1 = random.randint(1, 10)
+            x2 = random.randint(1, 10)
+            stmt2 = (
+                stmt.params(x=x1)
+                .subquery()
+                .select()
+                .params(y=x2)
+                .order_by("ord")
+            )
+            eq_(cs(stmt2).all(), [x1, x2])
+            eq_(cs(stmt2.params({"x": x1}).params({"y": x2})).all(), [x1, x2])
+
+    def test_text(self, connection):
+        a = self.tables.a
+
+        cs = connection.scalars
+
+        for _ in range(3):
+            x0 = random.randint(1, 10)
+            stmt = text("select data from a where data = :x").params(x=x0)
+            eq_(cs(stmt).all(), [x0])
+
+            x1 = random.randint(1, 10)
+            eq_(cs(stmt.params({"x": x1})).all(), [x1])
+
+            x2 = random.randint(1, 10)
+            stmt2 = stmt.columns(a.c.data).params(x=x2)
+            eq_(cs(stmt2).all(), [x2])
+            eq_(cs(stmt2, {"x": 1}).all(), [1])
+            eq_(cs(stmt2.params(x=1)).all(), [1])