New Features and Improvements - Core
=====================================
-
.. _change_10635:
``Row`` now represents individual column types directly without ``Tuple``
:ticket:`11234`
+.. _change_7066:
+
+Improved ``params()`` implementation for executable statements
+--------------------------------------------------------------
+
+The :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params`
+methods have been deprecated in favor of a new implementation on executable
+statements that provides improved performance and better integration with
+ORM-enabled statements.
+
+Executable statement objects like :class:`_sql.Select`, :class:`_sql.CompoundSelect`,
+and :class:`_sql.TextClause` now provide an improved :meth:`_sql.ExecutableStatement.params`
+method that avoids a full cloned traversal of the statement tree. Instead, parameters
+are stored directly on the statement object and efficiently merged during compilation
+and/or cache key traversal.
+
+The new implementation provides several benefits:
+
+* **Better performance** - Parameters are stored in a simple dictionary rather than
+ requiring a full statement tree traversal with cloning
+* **Proper caching integration** - Parameters are correctly integrated into SQLAlchemy's
+ cache key system via ``_generate_cache_key()``
+* **ORM statement compatibility** - Works correctly with ORM-enabled statements, including
+ ORM entities used with :func:`_orm.aliased`, subqueries, CTEs, etc.
+
+Use of :meth:`_sql.ExecutableStatement.params` is unchanged, provided the given
+object is a statement object such as :func:`_sql.select`::
+
+ stmt = select(table).where(table.c.data == bindparam("x"))
+
+ # Execute with parameter value
+ result = connection.execute(stmt.params(x=5))
+
+ # Can be chained and used in subqueries
+ stmt2 = stmt.params(x=6).subquery().select()
+ result = connection.execute(stmt2.params(x=7)) # Uses x=7
+
+The deprecated :meth:`_sql.ClauseElement.params` and :meth:`_sql.ClauseElement.unique_params`
+methods on non-executable elements like :class:`_sql.ColumnElement` and general
+:class:`_sql.ClauseElement` instances will continue to work during the deprecation
+period but will emit deprecation warnings.
+
+:ticket:`7066`
+
+
.. _change_4950:
CREATE TABLE AS SELECT Support
--- /dev/null
+.. change::
+ :tags: change, sql
+ :tickets: 7066, 12915
+
+ Added new implementation for the :meth:`.Select.params` method and that of
+ similar statements, via a new statement-only
+ :meth:`.ExecutableStatement.params` method which works more efficiently and
+ correctly than the previous implementations available from
+ :class:`.ClauseElement`, by assocating the given parameter dictionary with
+ the statement overall rather than cloning the statement and rewriting its
+ bound parameters. The :meth:`_sql.ClauseElement.params` and
+ :meth:`_sql.ClauseElement.unique_params` methods, when called on an object
+ that does not implement :class:`.ExecutableStatement`, will continue to
+ work the old way of cloning the object, and will emit a deprecation
+ warning. This issue both resolves the architectural / performance
+ concerns of :ticket:`7066` and also provides correct ORM compatibility for
+ functions like :func:`_orm.aliased`, reported by :ticket:`12915`.
+
+ .. seealso::
+
+ :ref:`change_7066`
.. autoclass:: Executable
:members:
+.. autoclass:: ExecutableStatement
+ :members:
+
.. autoclass:: Exists
:members:
"compiled_cache", self.engine._compiled_cache
)
- compiled_sql, extracted_params, cache_hit = elem._compile_w_cache(
- dialect=dialect,
- compiled_cache=compiled_cache,
- column_keys=keys,
- for_executemany=for_executemany,
- schema_translate_map=schema_translate_map,
- linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ compiled_sql, extracted_params, param_dict, cache_hit = (
+ elem._compile_w_cache(
+ dialect=dialect,
+ compiled_cache=compiled_cache,
+ column_keys=keys,
+ for_executemany=for_executemany,
+ schema_translate_map=schema_translate_map,
+ linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
+ )
)
ret = self._execute_context(
dialect,
elem,
extracted_params,
cache_hit=cache_hit,
+ param_dict=param_dict,
)
if has_events:
self.dispatch.after_execute(
invoked_statement: Executable,
extracted_parameters: Optional[Sequence[BindParameter[Any]]],
cache_hit: CacheStats = CacheStats.CACHING_DISABLED,
+ param_dict: _CoreSingleExecuteParams | None = None,
) -> ExecutionContext:
"""Initialize execution context for a Compiled construct."""
compiled.construct_params(
extracted_parameters=extracted_parameters,
escape_names=False,
+ _collected_params=param_dict,
)
]
else:
escape_names=False,
_group_number=grp,
extracted_parameters=extracted_parameters,
+ _collected_params=param_dict,
)
for grp, m in enumerate(parameters)
]
from ..sql.base import CacheableOptions
from ..sql.base import CompileState
from ..sql.base import Executable
+from ..sql.base import ExecutableStatement
from ..sql.base import Generative
from ..sql.base import Options
from ..sql.dml import UpdateBase
_traverse_internals = [
("_raw_columns", InternalTraversal.dp_clauseelement_list),
("element", InternalTraversal.dp_clauseelement),
- ] + Executable._executable_traverse_internals
+ ] + ExecutableStatement._executable_traverse_internals
_cache_key_traversal = _traverse_internals + [
("_compile_options", InternalTraversal.dp_has_cache_key)
from ._typing import NotNullable as NotNullable
from ._typing import Nullable as Nullable
from .base import Executable as Executable
+from .base import ExecutableStatement as ExecutableStatement
from .base import SyntaxExtension as SyntaxExtension
from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
from .compiler import FROM_LINTING as FROM_LINTING
if TYPE_CHECKING:
from .cache_key import CacheConst
+ from ..engine.interfaces import _CoreSingleExecuteParams
# START GENERATED CYTHON IMPORT
# This section is automatically generated by the script tools/cython_imports.py
return value
+_AM_KEY = Union[int, str, "CacheConst"]
+_AM_VALUE = Union[int, Literal[True], "_CoreSingleExecuteParams"]
+
+
@cython.cclass
-class anon_map(
- Dict[
- Union[int, str, "Literal[CacheConst.NO_CACHE]"],
- Union[int, Literal[True]],
- ]
-):
+class anon_map(Dict[_AM_KEY, _AM_VALUE]):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
@cython.cfunc # type:ignore[misc]
@cython.inline # type:ignore[misc]
- def _add_missing(
- self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
- ) -> int:
+ def _add_missing(self: anon_map, key: _AM_KEY, /) -> int:
val: int = self._index
self._index += 1
self_dict: dict = self # type: ignore[type-arg]
if cython.compiled:
- def __getitem__(
- self: anon_map,
- key: Union[int, str, "Literal[CacheConst.NO_CACHE]"],
- /,
- ) -> Union[int, Literal[True]]:
+ def __getitem__(self: anon_map, key: _AM_KEY, /) -> _AM_VALUE:
self_dict: dict = self # type: ignore[type-arg]
if key in self_dict:
else:
return self._add_missing(key) # type:ignore[no-any-return]
- def __missing__(
- self: anon_map, key: Union[int, str, "Literal[CacheConst.NO_CACHE]"], /
- ) -> int:
+ def __missing__(self: anon_map, key: _AM_KEY, /) -> int:
return self._add_missing(key) # type:ignore[no-any-return]
from .. import event
from .. import exc
from .. import util
+from ..util import EMPTY_DICT
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util.typing import Self
from ..engine import Connection
from ..engine import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _ExecuteOptions
from ..engine.interfaces import _ImmutableExecuteOptions
from ..engine.interfaces import CacheStats
@hybridmethod
def _generate_cache_key(self) -> Optional[CacheKey]:
- return HasCacheKey._generate_cache_key_for_object(self)
+ return HasCacheKey._generate_cache_key(self)
class ExecutableOption(HasCopyInternals):
for_executemany: bool = False,
schema_translate_map: Optional[SchemaTranslateMapType] = None,
**kw: Any,
- ) -> Tuple[
- Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+ ) -> tuple[
+ Compiled,
+ Sequence[BindParameter[Any]] | None,
+ _CoreSingleExecuteParams | None,
+ CacheStats,
]: ...
def _execute_on_connection(
return self._execution_options
+class ExecutableStatement(Executable):
+ """Executable subclass that implements a lightweight version of ``params``
+ that avoids a full cloned traverse.
+
+ .. versionadded:: 2.1
+
+ """
+
+ _params: util.immutabledict[str, Any] = EMPTY_DICT
+
+ _executable_traverse_internals = (
+ Executable._executable_traverse_internals
+ + [("_params", InternalTraversal.dp_params)]
+ )
+
+ @_generative
+ def params(
+ self,
+ __optionaldict: _CoreSingleExecuteParams | None = None,
+ /,
+ **kwargs: Any,
+ ) -> Self:
+ """Return a copy with the provided bindparam values.
+
+ Returns a copy of this Executable with bindparam values set
+ to the given dictionary::
+
+ >>> clause = column("x") + bindparam("foo")
+ >>> print(clause.compile().params)
+ {'foo': None}
+ >>> print(clause.params({"foo": 7}).compile().params)
+ {'foo': 7}
+
+ """
+ if __optionaldict:
+ kwargs.update(__optionaldict)
+ self._params = (
+ util.immutabledict(kwargs)
+ if not self._params
+ else self._params | kwargs
+ )
+ return self
+
+
class SchemaEventTarget(event.EventTarget):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
from typing import Any
from typing import Callable
from typing import Dict
+from typing import Final
from typing import Iterable
from typing import Iterator
from typing import List
class CacheConst(enum.Enum):
NO_CACHE = 0
+ PARAMS = 1
-NO_CACHE = CacheConst.NO_CACHE
+NO_CACHE: Final = CacheConst.NO_CACHE
_CacheKeyTraversalType = Union[
return None
else:
assert key is not None
- return CacheKey(key, bindparams)
-
- @classmethod
- def _generate_cache_key_for_object(
- cls, obj: HasCacheKey
- ) -> Optional[CacheKey]:
- bindparams: List[BindParameter[Any]] = []
-
- _anon_map = anon_map()
- key = obj._gen_cache_key(_anon_map, bindparams)
- if NO_CACHE in _anon_map:
- return None
- else:
- assert key is not None
- return CacheKey(key, bindparams)
+ return CacheKey(
+ key,
+ bindparams,
+ _anon_map.get(CacheConst.PARAMS), # type: ignore[arg-type]
+ )
class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey):
key: Tuple[Any, ...]
bindparams: Sequence[BindParameter[Any]]
+ params: _CoreSingleExecuteParams | None
# can't set __hash__ attribute because it interferes
# with namedtuple
@classmethod
def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
- ck1 = CacheKey(left, [])
- ck2 = CacheKey(right, [])
+ ck1 = CacheKey(left, [], None)
+ ck2 = CacheKey(right, [], None)
return ck1._diff(ck2)
def _whats_different(self, other: CacheKey) -> Iterator[str]:
anon_map[NO_CACHE] = True
return ()
+ def visit_params(
+ self,
+ attrname: str,
+ obj: Any,
+ parent: Any,
+ anon_map: anon_map,
+ bindparams: List[BindParameter[Any]],
+ ) -> Tuple[Any, ...]:
+ if obj:
+ if CacheConst.PARAMS in anon_map:
+ to_set = anon_map[CacheConst.PARAMS] | obj
+ else:
+ to_set = obj
+ anon_map[CacheConst.PARAMS] = to_set
+ return ()
+
_cache_key_traversal_visitor = _CacheKeyTraversal()
from typing import cast
from typing import ClassVar
from typing import Dict
+from typing import Final
from typing import FrozenSet
from typing import Iterable
from typing import Iterator
from .base import _AmbiguousTableNameMap
from .base import CompileState
from .base import Executable
+ from .base import ExecutableStatement
from .cache_key import CacheKey
from .ddl import CreateTableAs
from .ddl import ExecutableDDLElement
_positional_pattern = re.compile(
f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
)
+ _collect_params: Final[bool]
+ _collected_params: util.immutabledict[str, Any]
@classmethod
def _init_compiler_cls(cls):
# dialect.label_length or dialect.max_identifier_length
self.truncated_names: Dict[Tuple[str, str], str] = {}
self._truncated_counters: Dict[str, int] = {}
+ if not cache_key:
+ self._collect_params = True
+ self._collected_params = util.EMPTY_DICT
+ else:
+ self._collect_params = False # type: ignore[misc]
Compiled.__init__(self, dialect, statement, **kwargs)
def _global_attributes(self) -> Dict[Any, Any]:
return {}
+ def _add_to_params(self, item: ExecutableStatement) -> None:
+ # assumes that this is called before traversing the statement
+ # so the call happens outer to inner, meaning that existing params
+ # take precedence
+ if item._params:
+ self._collected_params = item._params | self._collected_params
+
@util.memoized_instancemethod
def _init_cte_state(self) -> MutableMapping[CTE, str]:
"""Initialize collections related to CTEs only if
_group_number: Optional[int] = None,
_check: bool = True,
_no_postcompile: bool = False,
+ _collected_params: _CoreSingleExecuteParams | None = None,
) -> _MutableCoreSingleExecuteParams:
"""return a dictionary of bind parameter keys and values"""
+ if _collected_params is not None:
+ assert not self._collect_params
+ elif self._collect_params:
+ _collected_params = self._collected_params
+
+ if _collected_params:
+ if not params:
+ params = _collected_params
+ else:
+ params = {**_collected_params, **params}
if self._render_postcompile and not _no_postcompile:
assert self._post_compile_expanded_state is not None
return text
def visit_textclause(self, textclause, add_to_result_map=None, **kw):
+ if self._collect_params:
+ self._add_to_params(textclause)
+
def do_bindparam(m):
name = m.group(1)
if name in textclause._bindparams:
def visit_textual_select(
self, taf, compound_index=None, asfrom=False, **kw
):
+ if self._collect_params:
+ self._add_to_params(taf)
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
add_to_result_map: Optional[_ResultMapAppender] = None,
**kwargs: Any,
) -> str:
+ if self._collect_params:
+ self._add_to_params(func)
if add_to_result_map is not None:
add_to_result_map(func.name, func.name, (func.name,), func.type)
def visit_compound_select(
self, cs, asfrom=False, compound_index=None, **kwargs
):
+ if self._collect_params:
+ self._add_to_params(cs)
toplevel = not self.stack
compile_state = cs._compile_state_factory(cs, self, **kwargs)
"the translate_select_structure hook for structural "
"translations of SELECT objects"
)
+ if self._collect_params:
+ self._add_to_params(select_stmt)
# initial setup of SELECT. the compile_state_factory may now
# be creating a totally different SELECT from the one that was
from .schema import Sequence as Sequence # noqa: F401
from .schema import Table
from ..engine.base import Connection
+ from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import CacheStats
from ..engine.interfaces import CompiledCacheType
from ..engine.interfaces import Dialect
for_executemany: bool = False,
schema_translate_map: Optional[SchemaTranslateMapType] = None,
**kw: Any,
- ) -> Tuple[
- Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats
+ ) -> tuple[
+ Compiled,
+ typing_Sequence[BindParameter[Any]] | None,
+ _CoreSingleExecuteParams | None,
+ CacheStats,
]:
raise NotImplementedError()
from .base import CompileState
from .base import DialectKWArgs
from .base import Executable
+from .base import ExecutableStatement
from .base import Generative
from .base import HasCompileState
from .base import HasSyntaxExtensions
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
+ HasCTE._has_ctes_traverse_internals
)
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
+ HasCTE._has_ctes_traverse_internals
)
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
+ HasCTE._has_ctes_traverse_internals
)
from .base import _generative
from .base import _NoArg
from .base import Executable
+from .base import ExecutableStatement
from .base import Generative
from .base import HasMemoized
from .base import Immutable
from .. import exc
from .. import inspection
from .. import util
+from ..util import deprecated
from ..util import HasMemoized_ro_memoized_attribute
from ..util import TypingOnly
from ..util.typing import Self
from ..engine import Connection
from ..engine import Dialect
from ..engine.interfaces import _CoreMultiExecuteParams
+ from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import CacheStats
from ..engine.interfaces import CompiledCacheType
from ..engine.interfaces import CoreExecuteOptionsParameter
"""
return self._replace_params(False, __optionaldict, kwargs)
+ @deprecated(
+ "2.1",
+ "The params() and unique_params() methods on non-statement "
+ "ClauseElement objects are deprecated; params() is now limited to "
+ "statement level objects such as select(), insert(), union(), etc. ",
+ )
def _replace_params(
self,
unique: bool,
for_executemany: bool = False,
schema_translate_map: Optional[SchemaTranslateMapType] = None,
**kw: Any,
- ) -> typing_Tuple[
- Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats
+ ) -> tuple[
+ Compiled,
+ Sequence[BindParameter[Any]] | None,
+ _CoreSingleExecuteParams | None,
+ CacheStats,
]:
elem_cache_key: Optional[CacheKey]
if TYPE_CHECKING:
assert compiled_cache is not None
- cache_key, extracted_params = elem_cache_key
+ cache_key, extracted_params, param_dict = elem_cache_key
key = (
dialect,
cache_key,
schema_translate_map=schema_translate_map,
**kw,
)
+ # ensure that params of the current statement are not
+ # left in the cache
+ assert not compiled_sql._collect_params # type: ignore[attr-defined] # noqa: E501
compiled_cache[key] = compiled_sql
else:
cache_hit = dialect.CACHE_HIT
else:
+ param_dict = None
extracted_params = None
compiled_sql = self._compiler(
dialect,
- cache_key=elem_cache_key,
+ cache_key=None,
column_keys=column_keys,
for_executemany=for_executemany,
schema_translate_map=schema_translate_map,
**kw,
)
+ # here instead the params need to be extracted, since we don't
+ # have them otherwise
+ assert compiled_sql._collect_params # type: ignore[attr-defined] # noqa: E501
if not dialect._supports_statement_cache:
cache_hit = dialect.NO_DIALECT_SUPPORT
else:
cache_hit = dialect.NO_CACHE_KEY
- return compiled_sql, extracted_params, cache_hit
+ return compiled_sql, extracted_params, param_dict, cache_hit
def __invert__(self):
# undocumented element currently used by the ORM for
roles.SelectStatementRole,
roles.InElementRole,
Generative,
- Executable,
+ ExecutableStatement,
DQLDMLClauseElement,
roles.BinaryElementRole[Any],
inspection.Inspectable["TextClause"],
_traverse_internals: _TraverseInternalsType = [
("_bindparams", InternalTraversal.dp_string_clauseelement_dict),
("text", InternalTraversal.dp_string),
- ] + Executable._executable_traverse_internals
+ ] + ExecutableStatement._executable_traverse_internals
_is_text_clause = True
from .base import _select_iterables as _select_iterables
from .base import ColumnCollection as ColumnCollection
from .base import Executable as Executable
+from .base import ExecutableStatement as ExecutableStatement
from .cache_key import CacheKey as CacheKey
from .dml import Delete as Delete
from .dml import Insert as Insert
from ._typing import is_table_value_type
from .base import _entity_namespace
from .base import ColumnCollection
-from .base import Executable
+from .base import ExecutableStatement
from .base import Generative
from .base import HasMemoized
from .elements import _type_from_args
reg[identifier] = fn
-class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
+class FunctionElement(
+ ColumnElement[_T], ExecutableStatement, FromClause, Generative
+):
"""Base for SQL function-oriented constructs.
This is a `generic type <https://peps.python.org/pep-0484/#generics>`_,
("clause_expr", InternalTraversal.dp_clauseelement),
("_with_ordinality", InternalTraversal.dp_boolean),
("_table_value_type", InternalTraversal.dp_has_cache_key),
- ] + Executable._executable_traverse_internals
+ ] + ExecutableStatement._executable_traverse_internals
packagenames: Tuple[str, ...] = ()
from . import visitors
from .base import _clone
from .base import Executable
+from .base import ExecutableStatement
from .base import Options
from .cache_key import CacheConst
from .operators import ColumnOperators
class StatementLambdaElement(
- roles.AllowsLambdaRole, LambdaElement, Executable
+ roles.AllowsLambdaRole, ExecutableStatement, LambdaElement
):
"""Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
+from .base import ExecutableStatement
from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
raise NotImplementedError()
-class ExecutableReturnsRows(Executable, ReturnsRows):
+class ExecutableReturnsRows(ExecutableStatement, ReturnsRows):
"""base for executable statements that return rows."""
+ SupportsCloneAnnotations._clone_annotations_traverse_internals
+ HasCTE._has_ctes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
)
selects: List[SelectBase]
+ HasSuffixes._has_suffixes_traverse_internals
+ HasHints._has_hints_traverse_internals
+ SupportsCloneAnnotations._clone_annotations_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
)
]
+ SupportsCloneAnnotations._clone_annotations_traverse_internals
+ HasCTE._has_ctes_traverse_internals
- + Executable._executable_traverse_internals
+ + ExecutableStatement._executable_traverse_internals
)
_is_textual = True
):
return COMPARE_FAILED
+ def visit_params(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ return left == right
+
def compare_expression_clauselist(self, left, right, **kw):
if left.operator is right.operator:
if operators.is_associative(left.operator):
"""Visit a list of inspectable objects which upon inspection are
HasCacheKey objects."""
+ dp_params = "PM"
+ """Visit the _params collection of ExecutableStatement"""
+
_TraverseInternalsType = List[Tuple[str, InternalTraversal]]
"""a structure that defines how a HasTraverseInternals should be
from sqlalchemy.orm import with_polymorphic
from sqlalchemy.sql import and_
from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql import visitors
from sqlalchemy.sql.selectable import Join as core_join
from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY
from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
checkparams={"param_1": 6, "param_2": 5},
)
+ @testing.variation("use_get_params", [True, False])
+ def test_annotated_cte_params_traverse(self, use_get_params):
+ """test #12915
+
+ Tests the original issue in #12915 which was a specific issue
+ involving cloned_traverse with Annotated subclasses, where traversal
+ would not properly cover a CTE's self-referential structure.
+
+ This case still does not work in the general ORM case, so the
+ implementation of .params() was changed to not rely upon
+ cloned_traversal.
+
+ """
+ User = self.classes.User
+ ids_param = bindparam("ids")
+ cte = select(User).where(User.id == ids_param).cte("cte")
+ ca = cte._annotate({"foo": "bar"})
+ stmt = select(ca)
+ if use_get_params:
+ stmt = stmt.params(ids=17)
+ else:
+ # test without using params(), in case the implementation
+ # for params() changes we still want to test cloned_traverse
+ def visit_bindparam(bind):
+ if bind.key == "ids":
+ bind.value = 17
+ bind.required = False
+
+ stmt = visitors.cloned_traverse(
+ stmt,
+ {"maintain_key": True, "detect_subquery_cols": True},
+ {"bindparam": visit_bindparam},
+ )
+ self.assert_compile(
+ stmt,
+ "WITH cte AS (SELECT users.id AS id, users.name AS name "
+ "FROM users WHERE users.id = :ids) "
+ "SELECT cte.id, cte.name FROM cte",
+ checkparams={"ids": 17},
+ )
+
+ def test_orm_cte_with_params(self, connection):
+ """test for #12915's new implementation"""
+ User = self.classes.User
+ ids_param = bindparam("ids")
+ cte = select(User).where(User.id == ids_param).cte("cte")
+ stmt = select(aliased(User, cte.alias("a1"), adapt_on_names=True))
+
+ res = connection.execute(stmt, {"ids": 7}).all()
+ eq_(res, [(7, "jack")])
+ with Session(connection) as s:
+ res = s.scalars(stmt, {"ids": 7}).all()
+ eq_(res, [User(id=7, name="jack")])
+
class PropagateAttrsTest(QueryTest):
__backend__ = True
eq_(
re.compile(r"[\n\s]+", re.M).sub(
" ",
- str(CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[])),
+ str(
+ CacheKey(
+ key=((1, (2, 7, 4), 5),), bindparams=[], params={}
+ )
+ ),
),
"CacheKey(key=( ( 1, ( 2, 7, 4, ), 5, ), ),)",
)
def test_nested_tuple_difference(self):
"""Test difference detection in nested tuples"""
- k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[])
- k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[])
+ k1 = CacheKey(key=((1, (2, 3, 4), 5),), bindparams=[], params={})
+ k2 = CacheKey(key=((1, (2, 7, 4), 5),), bindparams=[], params={})
eq_(list(k1._whats_different(k2)), ["key[0][1][1]: 3 != 7"])
def test_deeply_nested_tuple_difference(self):
"""Test difference detection in deeply nested tuples"""
- k1 = CacheKey(key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[])
- k2 = CacheKey(key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[])
+ k1 = CacheKey(
+ key=((1, (2, (3, 4, 5), 6), 7),), bindparams=[], params={}
+ )
+ k2 = CacheKey(
+ key=((1, (2, (3, 9, 5), 6), 7),), bindparams=[], params={}
+ )
eq_(list(k1._whats_different(k2)), ["key[0][1][1][1]: 4 != 9"])
def test_multiple_differences_nested(self):
"""Test detection of multiple differences in nested structure"""
- k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[])
- k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[])
+ k1 = CacheKey(key=((1, (2, 3), 4),), bindparams=[], params={})
+ k2 = CacheKey(key=((1, (5, 7), 4),), bindparams=[], params={})
eq_(
list(k1._whats_different(k2)),
def test_diff_method(self):
"""Test the _diff() method that returns a comma-separated string"""
- k1 = CacheKey(key=((1, (2, 3)),), bindparams=[])
- k2 = CacheKey(key=((1, (5, 7)),), bindparams=[])
+ k1 = CacheKey(key=((1, (2, 3)),), bindparams=[], params={})
+ k2 = CacheKey(key=((1, (5, 7)),), bindparams=[], params={})
eq_(k1._diff(k2), "key[0][1][0]: 2 != 5, key[0][1][1]: 3 != 7")
def test_with_string_differences(self):
"""Test detection of string differences"""
- k1 = CacheKey(key=(("name", ("x", "value")),), bindparams=[])
- k2 = CacheKey(key=(("name", ("y", "value")),), bindparams=[])
+ k1 = CacheKey(
+ key=(("name", ("x", "value")),), bindparams=[], params={}
+ )
+ k2 = CacheKey(
+ key=(("name", ("y", "value")),), bindparams=[], params={}
+ )
eq_(list(k1._whats_different(k2)), ["key[0][1][0]: x != y"])
def test_with_mixed_types(self):
"""Test detection of differences with mixed types"""
- k1 = CacheKey(key=(("id", 1, ("nested", 100)),), bindparams=[])
- k2 = CacheKey(key=(("id", 1, ("nested", 200)),), bindparams=[])
+ k1 = CacheKey(
+ key=(("id", 1, ("nested", 100)),), bindparams=[], params={}
+ )
+ k2 = CacheKey(
+ key=(("id", 1, ("nested", 200)),), bindparams=[], params={}
+ )
eq_(list(k1._whats_different(k2)), ["key[0][2][1]: 100 != 200"])
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_not
+from sqlalchemy.testing.assertions import expect_deprecated
from sqlalchemy.testing.schema import eq_clause_element
A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None
params = {"xb": 42, "yb": 33}
sel = select(Y).select_from(jj).params(params)
+ cc = sel._generate_cache_key()
+
eq_(
[
- eq_clause_element(bindparam("yb", value=33)),
- eq_clause_element(bindparam("xb", value=42)),
+ eq_clause_element(bindparam("yb", None, Integer)),
+ eq_clause_element(bindparam("xb", None, Integer)),
],
- sel._generate_cache_key()[1],
+ cc[1],
)
+ eq_(cc[2], {"xb": 42, "yb": 33})
+ eq_(sel.compile().params, params)
def test_dont_traverse_immutables(self):
meta = MetaData()
stmt._generate_cache_key()[1],
)
- stmt = select(b).where(criteria.params({param_key: "some other data"}))
+ with expect_deprecated(
+ r"The params\(\) and unique_params\(\) methods on non-statement"
+ ):
+ stmt = select(b).where(
+ criteria.params({param_key: "some other data"})
+ )
self.assert_compile(
stmt,
"SELECT b.id, b.data FROM b, (SELECT b.id AS id "
Ps,
join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)),
and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p),
- ).params(params)
+ )
+ with expect_deprecated(
+ r"The params\(\) and unique_params\(\) methods on non-statement"
+ ):
+ jj = jj.params(params)
eq_(
[
pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
- jj = (
- join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p))
- .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p))
- .params(params)
+ jj = join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p)).join(
+ s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p)
)
+ with expect_deprecated(
+ r"The params\(\) and unique_params\(\) methods on non-statement"
+ ):
+ jj = jj.params(params)
eq_(
[
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
+from sqlalchemy.testing.assertions import expect_deprecated
from sqlalchemy.testing.assertions import expect_raises_message
from sqlalchemy.testing.assertions import expect_warnings
from sqlalchemy.testing.engines import all_dialects
"WHERE users.id > anon_1.z",
)
- s = select(users).where(
- users.c.id.between(
- calculate.alias("c1").unique_params(x=17, y=45).c.z,
- calculate.alias("c2").unique_params(x=5, y=12).c.z,
- ),
- )
+ with expect_deprecated(
+ r"The params\(\) and unique_params\(\) methods on non-statement"
+ ):
+ s = select(users).where(
+ users.c.id.between(
+ calculate.alias("c1").unique_params(x=17, y=45).c.z,
+ calculate.alias("c2").unique_params(x=5, y=12).c.z,
+ ),
+ )
self.assert_compile(
s,
"""
user = Table("user", MetaData(), Column("id", Integer))
- ids_param = bindparam("ids")
+ ids_param = bindparam("ids", -1)
cte = select(user).where(user.c.id == ids_param).cte("cte")
if use_get_params:
stmt = stmt.params(ids=17)
+ exp = -1
+ eq_(stmt._generate_cache_key()[2], {"ids": 17})
else:
# test without using params(), as the implementation
# for params() will be changing
{"maintain_key": True, "detect_subquery_cols": True},
{"bindparam": visit_bindparam},
)
+ exp = 17
+ eq_(stmt._generate_cache_key()[2], None)
eq_(
stmt.selected_columns.id.table.element._where_criteria[
0
].right.value,
- 17,
+ exp,
)
+ eq_(stmt.compile().params, {"ids": 17})
def test_basic_attrs(self):
t = Table(
--- /dev/null
+import random
+
+from sqlalchemy import testing
+from sqlalchemy.schema import Column
+from sqlalchemy.sql import bindparam
+from sqlalchemy.sql import column
+from sqlalchemy.sql import dml
+from sqlalchemy.sql import func
+from sqlalchemy.sql import select
+from sqlalchemy.sql import text
+from sqlalchemy.sql.base import ExecutableStatement
+from sqlalchemy.sql.elements import literal
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_deprecated
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_not
+from sqlalchemy.testing import ne_
+from sqlalchemy.testing.schema import Table
+from sqlalchemy.types import Integer
+from sqlalchemy.types import Text
+from sqlalchemy.util.langhelpers import class_hierarchy
+
+
+class BasicTests(fixtures.TestBase):
+ def _all_subclasses(self, cls_):
+ return dict.fromkeys(
+ s
+ for s in class_hierarchy(cls_)
+ # class_hierarchy may return values that
+ # aren't subclasses of cls
+ if issubclass(s, cls_)
+ )
+
+ @staticmethod
+ def _relevant_impls():
+ return (
+ text("select 1 + 2"),
+ text("select 42 as q").columns(column("q", Integer)),
+ func.max(42),
+ select(1, 2).union(select(3, 4)),
+ select(1, 2),
+ )
+
+ def test_params_impl(self):
+ exclude = (dml.UpdateBase,)
+ visit_names = set()
+ for cls_ in self._all_subclasses(ExecutableStatement):
+ if not issubclass(cls_, exclude):
+ if "__visit_name__" in cls_.__dict__:
+ visit_names.add(cls_.__visit_name__)
+ eq_(cls_.params, ExecutableStatement.params, cls_)
+ else:
+ ne_(cls_.params, ExecutableStatement.params, cls_)
+ for other in exclude:
+ if issubclass(cls_, other):
+ eq_(cls_.params, other.params, cls_)
+ break
+ else:
+ assert False
+
+ extra = {"orm_from_statement"}
+ eq_(
+ visit_names - extra,
+ {i.__visit_name__ for i in self._relevant_impls()},
+ )
+
+ @testing.combinations(*_relevant_impls())
+ def test_compile_params(self, impl):
+ new = impl.params(foo=5, bar=10)
+ is_not(new, impl)
+ eq_(impl.compile()._collected_params, {})
+ eq_(new.compile()._collected_params, {"foo": 5, "bar": 10})
+ eq_(new._generate_cache_key()[2], {"foo": 5, "bar": 10})
+
+
+class CacheTests(fixtures.TablesTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table("a", metadata, Column("data", Integer))
+ Table("b", metadata, Column("data", Text))
+
+ @classmethod
+ def insert_data(cls, connection):
+ connection.execute(
+ cls.tables.a.insert(),
+ [{"data": i} for i in range(1, 11)],
+ )
+ connection.execute(
+ cls.tables.b.insert(),
+ [{"data": "row %d" % i} for i in range(1, 11)],
+ )
+
+ def test_plain_select(self, connection):
+ a = self.tables.a
+
+ cs = connection.scalars
+
+ for _ in range(3):
+ x1 = random.randint(1, 10)
+
+ eq_(cs(select(a).where(a.c.data == x1)).all(), [x1])
+ stmt = select(a).where(a.c.data == bindparam("x", x1))
+ eq_(cs(stmt).all(), [x1])
+
+ x1 = random.randint(1, 10)
+ eq_(cs(stmt.params({"x": x1})).all(), [x1])
+
+ x1 = random.randint(1, 10)
+ eq_(cs(stmt, {"x": x1}).all(), [x1])
+
+ x1 = random.randint(1, 10)
+ x2 = random.randint(1, 10)
+ eq_(cs(stmt.params({"x": x1}), {"x": x2}).all(), [x2])
+
+ stmt2 = stmt.params(x=6).subquery().select()
+ eq_(cs(stmt2).all(), [6])
+ eq_(cs(stmt2.params({"x": 2})).all(), [2])
+
+ with expect_deprecated(
+ r"The params\(\) and unique_params\(\) "
+ "methods on non-statement"
+ ):
+ # NOTE: can't mix and match the two params styles here
+ stmt3 = stmt.params(x=6).subquery().params(x=8).select()
+ eq_(cs(stmt3).all(), [6])
+ eq_(cs(stmt3.params({"x": 9})).all(), [9])
+
+ def test_union(self, connection):
+ a = self.tables.a
+
+ cs = connection.scalars
+ for _ in range(3):
+ x1 = random.randint(1, 10)
+ x2 = random.randint(1, 10)
+
+ eq_(
+ cs(
+ select(a)
+ .where(a.c.data == x1)
+ .union_all(select(a).where(a.c.data == x2))
+ .order_by(a.c.data)
+ ).all(),
+ sorted([x1, x2]),
+ )
+
+ x1 = random.randint(1, 10)
+ x2 = random.randint(1, 10)
+ stmt = (
+ select(a, literal(1).label("ord"))
+ .where(a.c.data == bindparam("x", x1))
+ .union_all(
+ select(a, literal(2)).where(a.c.data == bindparam("y", x2))
+ )
+ .order_by("ord")
+ )
+ eq_(cs(stmt).all(), [x1, x2])
+
+ x1a = random.randint(1, 10)
+ eq_(cs(stmt.params({"x": x1a})).all(), [x1a, x2])
+
+ x2 = random.randint(1, 10)
+ eq_(cs(stmt, {"y": x2}).all(), [x1, x2])
+
+ x1 = random.randint(1, 10)
+ x2 = random.randint(1, 10)
+ eq_(cs(stmt.params({"x": x1}), {"y": x2}).all(), [x1, x2])
+
+ x1 = random.randint(1, 10)
+ x2 = random.randint(1, 10)
+ stmt2 = (
+ stmt.params(x=x1)
+ .subquery()
+ .select()
+ .params(y=x2)
+ .order_by("ord")
+ )
+ eq_(cs(stmt2).all(), [x1, x2])
+ eq_(cs(stmt2.params({"x": x1}).params({"y": x2})).all(), [x1, x2])
+
+ def test_text(self, connection):
+ a = self.tables.a
+
+ cs = connection.scalars
+
+ for _ in range(3):
+ x0 = random.randint(1, 10)
+ stmt = text("select data from a where data = :x").params(x=x0)
+ eq_(cs(stmt).all(), [x0])
+
+ x1 = random.randint(1, 10)
+ eq_(cs(stmt.params({"x": x1})).all(), [x1])
+
+ x2 = random.randint(1, 10)
+ stmt2 = stmt.columns(a.c.data).params(x=x2)
+ eq_(cs(stmt2).all(), [x2])
+ eq_(cs(stmt2, {"x": 1}).all(), [1])
+ eq_(cs(stmt2.params(x=1)).all(), [1])