--- /dev/null
+.. change::
+ :tags: feature, sql
+ :tickets: 12195
+
+ Added the ability to create custom SQL constructs that can define new
+ clauses within SELECT, INSERT, UPDATE, and DELETE statements without
+ needing to modify the construction or compilation code of of
+ :class:`.Select`, :class:`.Insert`, :class:`.Update`, or :class:`.Delete`
+ directly. Support for testing these constructs, including caching support,
+ is present along with an example test suite. The use case for these
+ constructs is expected to be third party dialects for analytical SQL
+ (so-called NewSQL) or other novel styles of database that introduce new
+ clauses to these statements. A new example suite is included which
+ illustrates the ``QUALIFY`` SQL construct used by several NewSQL databases
+ which includes a cachable implementation as well as a test suite.
+
+ .. seealso::
+
+ :ref:`examples.syntax_extensions`
+
.. automodule:: sqlalchemy.ext.compiler
:members:
+
+
+.. autoclass:: sqlalchemy.sql.SyntaxExtension
+ :members:
.. _examples_toplevel:
-============
-ORM Examples
-============
+=====================
+Core and ORM Examples
+=====================
The SQLAlchemy distribution includes a variety of code examples illustrating
a select set of patterns, some typical and some not so typical. All are
.. automodule:: examples.sharding
+Extending Core
+==============
+
+.. _examples_syntax_extensions:
+
+Extending Statements like SELECT, INSERT, etc
+----------------------------------------------
+
+.. automodule:: examples.syntax_extensions
+
Extending the ORM
=================
--- /dev/null
+"""
+A detailed example of extending the :class:`.Select` construct to include
+a new non-SQL standard clause ``QUALIFY``.
+
+This example illustrates both the :ref:`sqlalchemy.ext.compiler_toplevel`
+as well as an extension known as :class:`.SyntaxExtension`.
+
+.. autosource::
+
+"""
--- /dev/null
+from __future__ import annotations
+
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnExpressionArgument
+from sqlalchemy.sql import roles
+from sqlalchemy.sql import Select
+from sqlalchemy.sql import SyntaxExtension
+from sqlalchemy.sql import visitors
+
+
+def qualify(predicate: ColumnExpressionArgument[bool]) -> Qualify:
+ """Return a QUALIFY construct
+
+ E.g.::
+
+ stmt = select(qt_table).ext(
+ qualify(func.row_number().over(order_by=qt_table.c.o))
+ )
+
+ """
+ return Qualify(predicate)
+
+
+class Qualify(SyntaxExtension, ClauseElement):
+ """Define the QUALIFY class."""
+
+ predicate: ColumnElement[bool]
+ """A single column expression that is the predicate within the QUALIFY."""
+
+ _traverse_internals = [
+ ("predicate", visitors.InternalTraversal.dp_clauseelement)
+ ]
+ """This structure defines how SQLAlchemy can do a deep traverse of internal
+ contents of this structure. This is mostly used for cache key generation.
+ If the traversal is not written yet, the ``inherit_cache=False`` class
+ level attribute may be used to skip caching for the construct.
+ """
+
+ def __init__(self, predicate: ColumnExpressionArgument):
+ self.predicate = coercions.expect(
+ roles.WhereHavingRole, predicate, apply_propagate_attrs=self
+ )
+
+ def apply_to_select(self, select_stmt: Select) -> None:
+ """Called when the :meth:`.Select.ext` method is called.
+
+ The extension should apply itself to the :class:`.Select`, typically
+ using :meth:`.HasStatementExtensions.apply_syntax_extension_point`,
+ which receives a callable that receives a list of current elements to
+ be concatenated together and then returns a new list of elements to be
+ concatenated together in the final structure. The
+ :meth:`.SyntaxExtension.append_replacing_same_type` callable is
+ usually used for this.
+
+ """
+ select_stmt.apply_syntax_extension_point(
+ self.append_replacing_same_type, "post_criteria"
+ )
+
+
+@compiles(Qualify)
+def _compile_qualify(element, compiler, **kw):
+ """a compiles extension that delivers the SQL text for Qualify"""
+ return f"QUALIFY {compiler.process(element.predicate, **kw)}"
--- /dev/null
+import random
+import unittest
+
+from sqlalchemy import Column
+from sqlalchemy import func
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from .qualify import qualify
+
+qt_table = Table(
+ "qt",
+ MetaData(),
+ Column("i", Integer),
+ Column("p", Integer),
+ Column("o", Integer),
+)
+
+
+class QualifyCompileTest(AssertsCompiledSQL, fixtures.CacheKeySuite):
+ """A sample test suite for the QUALIFY clause, making use of SQLAlchemy
+ testing utilities.
+
+ """
+
+ __dialect__ = "default"
+
+ @fixtures.CacheKeySuite.run_suite_tests
+ def test_qualify_cache_key(self):
+ """A cache key suite using the ``CacheKeySuite.run_suite_tests``
+ decorator.
+
+ This suite intends to test that the "_traverse_internals" structure
+ of the custom SQL construct covers all the structural elements of
+ the object. A decorated function should return a callable (e.g.
+ a lambda) which returns a list of SQL structures. The suite will
+ call upon this lambda multiple times, to make the same list of
+ SQL structures repeatedly. It then runs comparisons of the generated
+ cache key for each element in a particular list to all the other
+ elements in that same list, as well as other versions of the list.
+
+ The rules for this list are then as follows:
+
+ * Each element of the list should store a SQL structure that is
+ **structurally identical** each time, for a given position in the
+ list. Successive versions of this SQL structure will be compared
+ to previous ones in the same list position and they must be
+ identical.
+
+ * Each element of the list should store a SQL structure that is
+ **structurally different** from **all other** elements in the list.
+ Successive versions of this SQL structure will be compared to
+ other members in other list positions, and they must be different
+ each time.
+
+ * The SQL structures returned in the list should exercise all of the
+ structural features that are provided by the construct. This is
+ to ensure that two different structural elements generate a
+ different cache key and won't be mis-cached.
+
+ * Literal parameters like strings and numbers are **not** part of the
+ cache key itself since these are not "structural" elements; two
+ SQL structures that are identical can nonetheless have different
+ parameterized values. To better exercise testing that this variation
+ is not stored as part of the cache key, ``random`` functions like
+ ``random.randint()`` or ``random.choice()`` can be used to generate
+ random literal values within a single element.
+
+
+ """
+
+ def stmt0():
+ return select(qt_table)
+
+ def stmt1():
+ stmt = stmt0()
+
+ return stmt.ext(qualify(qt_table.c.p == random.choice([2, 6, 10])))
+
+ def stmt2():
+ stmt = stmt0()
+
+ return stmt.ext(
+ qualify(func.row_number().over(order_by=qt_table.c.o))
+ )
+
+ def stmt3():
+ stmt = stmt0()
+
+ return stmt.ext(
+ qualify(
+ func.row_number().over(
+ partition_by=qt_table.c.i, order_by=qt_table.c.o
+ )
+ )
+ )
+
+ return lambda: [stmt0(), stmt1(), stmt2(), stmt3()]
+
+ def test_query_one(self):
+ """A compilation test. This makes use of the
+ ``AssertsCompiledSQL.assert_compile()`` utility.
+
+ """
+
+ stmt = select(qt_table).ext(
+ qualify(
+ func.row_number().over(
+ partition_by=qt_table.c.p, order_by=qt_table.c.o
+ )
+ == 1
+ )
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT qt.i, qt.p, qt.o FROM qt QUALIFY row_number() "
+ "OVER (PARTITION BY qt.p ORDER BY qt.o) = :param_1",
+ )
+
+ def test_query_two(self):
+ """A compilation test. This makes use of the
+ ``AssertsCompiledSQL.assert_compile()`` utility.
+
+ """
+
+ row_num = (
+ func.row_number()
+ .over(partition_by=qt_table.c.p, order_by=qt_table.c.o)
+ .label("row_num")
+ )
+ stmt = select(qt_table, row_num).ext(
+ qualify(row_num.as_reference() == 1)
+ )
+
+ self.assert_compile(
+ stmt,
+ "SELECT qt.i, qt.p, qt.o, row_number() OVER "
+ "(PARTITION BY qt.p ORDER BY qt.o) AS row_num "
+ "FROM qt QUALIFY row_num = :param_1",
+ )
+
+ def test_propagate_attrs(self):
+ """ORM propagate test. this is an optional test that tests
+ apply_propagate_attrs, indicating when you pass ORM classes /
+ attributes to your construct, there's a dictionary called
+ ``._propagate_attrs`` that gets carried along to the statement,
+ which marks it as an "ORM" statement.
+
+ """
+ row_num = (
+ func.row_number().over(partition_by=qt_table.c.p).label("row_num")
+ )
+ row_num._propagate_attrs = {"foo": "bar"}
+
+ stmt = select(1).ext(qualify(row_num.as_reference() == 1))
+
+ eq_(stmt._propagate_attrs, {"foo": "bar"})
+
+
+class QualifyCompileUnittest(QualifyCompileTest, unittest.TestCase):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
passed to with_polymorphic (which is completely unnecessary in modern
use).
+ TODO: What is a "quasi-legacy" case? Do we need this method with
+ 2.0 style select() queries or not? Why is with_polymorphic referring
+ to an alias or subquery "legacy" ?
+
"""
if (
not ext_info.is_aliased_class
if opt._is_compile_state:
opt.process_compile_state(self)
- if statement_container._with_context_options:
- for fn, key in statement_container._with_context_options:
+ if statement_container._compile_state_funcs:
+ for fn, key in statement_container._compile_state_funcs:
fn(self)
self.primary_columns = []
# after it's been set up above
# self._dump_option_struct()
- if select_statement._with_context_options:
- for fn, key in select_statement._with_context_options:
+ if select_statement._compile_state_funcs:
+ for fn, key in select_statement._compile_state_funcs:
fn(self)
self.primary_columns = []
self.distinct = query._distinct
+ self.syntax_extensions = {
+ key: current_adapter(value, True) if current_adapter else value
+ for key, value in query._get_syntax_extensions_as_dict().items()
+ }
+
if query._correlate:
# ORM mapped entities that are mapped to joins can be passed
# to .correlate, so here they are broken into their component
stmt.__dict__.update(
_with_options=statement._with_options,
- _with_context_options=statement._with_context_options,
+ _compile_state_funcs=statement._compile_state_funcs,
_execution_options=statement._execution_options,
_propagate_attrs=statement._propagate_attrs,
)
group_by,
independent_ctes,
independent_ctes_opts,
+ syntax_extensions,
):
statement = Select._create_raw_select(
_raw_columns=raw_columns,
statement._fetch_clause_options = fetch_clause_options
statement._independent_ctes = independent_ctes
statement._independent_ctes_opts = independent_ctes_opts
+ if syntax_extensions:
+ statement._set_syntax_extensions(**syntax_extensions)
if prefixes:
statement._prefixes = prefixes
"independent_ctes_opts": (
self.select_statement._independent_ctes_opts
),
+ "syntax_extensions": self.syntax_extensions,
}
@property
from ..sql._typing import _TypedColumnClauseArgument as _TCCA
from ..sql.base import CacheableOptions
from ..sql.base import ExecutableOption
+ from ..sql.base import SyntaxExtension
from ..sql.dml import UpdateBase
from ..sql.elements import ColumnElement
from ..sql.elements import Label
_memoized_select_entities = ()
+ _syntax_extensions: Tuple[SyntaxExtension, ...] = ()
+
_compile_options: Union[Type[CacheableOptions], CacheableOptions] = (
_ORMCompileState.default_compile_options
)
stmt = FromStatement(self._raw_columns, self._statement)
stmt.__dict__.update(
_with_options=self._with_options,
- _with_context_options=self._with_context_options,
+ _with_context_options=self._compile_state_funcs,
_compile_options=compile_options,
_execution_options=self._execution_options,
_propagate_attrs=self._propagate_attrs,
else:
# Query / select() internal attributes are 99% cross-compatible
stmt = Select._create_raw_select(**self.__dict__)
+
stmt.__dict__.update(
_label_style=self._label_style,
_compile_options=compile_options,
_propagate_attrs=self._propagate_attrs,
)
+ for ext in self._syntax_extensions:
+ stmt._apply_syntax_extension_to_self(ext)
stmt.__dict__.pop("session", None)
# ensure the ORM context is used to compile the statement, even
"_having_criteria",
"_prefixes",
"_suffixes",
+ "_syntax_extensions",
):
self.__dict__.pop(attr, None)
self._set_select_from([fromclause], set_entity_from)
self._distinct = True
return self
+ @_generative
+ def ext(self, extension: SyntaxExtension) -> Self:
+ """Applies a SQL syntax extension to this statement.
+
+ .. seealso::
+
+ :ref:`examples_syntax_extensions`
+
+ .. versionadded:: 2.1
+
+ """
+
+ extension = coercions.expect(roles.SyntaxExtensionRole, extension)
+ self._syntax_extensions += (extension,)
+ return self
+
def all(self) -> List[_T]:
"""Return the results represented by this :class:`_query.Query`
as a list.
delete_ = delete_.with_dialect_options(**delete_args)
delete_._where_criteria = self._where_criteria
+
+ for ext in self._syntax_extensions:
+ delete_._apply_syntax_extension_to_self(ext)
+
result: CursorResult[Any] = self.session.execute(
delete_,
self._params,
upd = upd.with_dialect_options(**update_args)
upd._where_criteria = self._where_criteria
+
+ for ext in self._syntax_extensions:
+ upd._apply_syntax_extension_to_self(ext)
+
result: CursorResult[Any] = self.session.execute(
upd,
self._params,
]
).lazyload(rev).process_compile_state(compile_context)
- stmt._with_context_options += (
- (_lazyload_reverse, self.parent_property),
+ stmt = stmt._add_compile_state_func(
+ _lazyload_reverse, self.parent_property
)
lazy_clause, params = self._generate_lazy_clause(state, passive)
util.to_list(self.parent_property.order_by)
)
- q = q._add_context_option(
+ q = q._add_compile_state_func(
_setup_outermost_orderby, self.parent_property
)
util.to_list(self.parent_property.order_by)
)
- q = q._add_context_option(
+ q = q._add_compile_state_func(
_setup_outermost_orderby, self.parent_property
)
from ._typing import NotNullable as NotNullable
from ._typing import Nullable as Nullable
from .base import Executable as Executable
+from .base import SyntaxExtension as SyntaxExtension
from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
from .compiler import FROM_LINTING as FROM_LINTING
from .compiler import NO_LINTING as NO_LINTING
from ..util import hybridmethod
from ..util.typing import Self
from ..util.typing import TypeGuard
+from ..util.typing import TypeVarTuple
+from ..util.typing import Unpack
if TYPE_CHECKING:
from . import coercions
from ._orm_types import SynchronizeSessionArgument
from ._typing import _CLE
from .compiler import SQLCompiler
+ from .dml import Delete
+ from .dml import Insert
+ from .dml import Update
from .elements import BindParameter
+ from .elements import ClauseElement
from .elements import ClauseList
from .elements import ColumnClause # noqa
from .elements import ColumnElement
from .selectable import _JoinTargetElement
from .selectable import _SelectIterable
from .selectable import FromClause
+ from .selectable import Select
from ..engine import Connection
from ..engine import CursorResult
from ..engine.interfaces import _CoreMultiExecuteParams
type_api = None # noqa
+_Ts = TypeVarTuple("_Ts")
+
+
class _NoArg(Enum):
NO_ARG = 0
return c
+_L = TypeVar("_L", bound=str)
+
+
+class HasSyntaxExtensions(Generic[_L]):
+
+ _position_map: Mapping[_L, str]
+
+ @_generative
+ def ext(self, extension: SyntaxExtension) -> Self:
+ """Applies a SQL syntax extension to this statement.
+
+ SQL syntax extensions are :class:`.ClauseElement` objects that define
+ some vendor-specific syntactical construct that take place in specific
+ parts of a SQL statement. Examples include vendor extensions like
+ PostgreSQL / SQLite's "ON DUPLICATE KEY UPDATE", PostgreSQL's
+ "DISTINCT ON", and MySQL's "LIMIT" that can be applied to UPDATE
+ and DELETE statements.
+
+ .. seealso::
+
+ :ref:`examples_syntax_extensions`
+
+ .. versionadded:: 2.1
+
+ """
+ extension = coercions.expect(
+ roles.SyntaxExtensionRole, extension, apply_propagate_attrs=self
+ )
+ self._apply_syntax_extension_to_self(extension)
+ return self
+
+ @util.preload_module("sqlalchemy.sql.elements")
+ def apply_syntax_extension_point(
+ self,
+ apply_fn: Callable[[Sequence[ClauseElement]], Sequence[ClauseElement]],
+ position: _L,
+ ) -> None:
+ """Apply a :class:`.SyntaxExtension` to a known extension point.
+
+ Should be used only internally by :class:`.SyntaxExtension`.
+
+ E.g.::
+
+ class Qualify(SyntaxExtension, ClauseElement):
+
+ # ...
+
+ def apply_to_select(self, select_stmt: Select) -> None:
+ # append self to existing
+ select_stmt.apply_extension_point(
+ lambda existing: [*existing, self], "post_criteria"
+ )
+
+
+ class ReplaceExt(SyntaxExtension, ClauseElement):
+
+ # ...
+
+ def apply_to_select(self, select_stmt: Select) -> None:
+ # replace any existing elements regardless of type
+ select_stmt.apply_extension_point(
+ lambda existing: [self], "post_criteria"
+ )
+
+
+ class ReplaceOfTypeExt(SyntaxExtension, ClauseElement):
+
+ # ...
+
+ def apply_to_select(self, select_stmt: Select) -> None:
+ # replace any existing elements of the same type
+ select_stmt.apply_extension_point(
+ self.append_replacing_same_type, "post_criteria"
+ )
+
+ :param apply_fn: callable function that will receive a sequence of
+ :class:`.ClauseElement` that is already populating the extension
+ point (the sequence is empty if there isn't one), and should return
+ a new sequence of :class:`.ClauseElement` that will newly populate
+ that point. The function typically can choose to concatenate the
+ existing values with the new one, or to replace the values that are
+ there with a new one by returning a list of a single element, or
+ to perform more complex operations like removing only the same
+ type element from the input list of merging already existing elements
+ of the same type. Some examples are shown in the examples above
+ :param position: string name of the position to apply to. This
+ varies per statement type. IDEs should show the possible values
+ for each statement type as it's typed with a ``typing.Literal`` per
+ statement.
+
+ .. seealso::
+
+ :ref:`examples_syntax_extensions`
+
+
+ """ # noqa: E501
+
+ try:
+ attrname = self._position_map[position]
+ except KeyError as ke:
+ raise ValueError(
+ f"Unknown position {position!r} for {self.__class__} "
+ f"construct; known positions: "
+ f"{', '.join(repr(k) for k in self._position_map)}"
+ ) from ke
+ else:
+ ElementList = util.preloaded.sql_elements.ElementList
+ existing: Optional[ClauseElement] = getattr(self, attrname, None)
+ if existing is None:
+ input_seq: Tuple[ClauseElement, ...] = ()
+ elif isinstance(existing, ElementList):
+ input_seq = existing.clauses
+ else:
+ input_seq = (existing,)
+
+ new_seq = apply_fn(input_seq)
+ assert new_seq, "cannot return empty sequence"
+ new = new_seq[0] if len(new_seq) == 1 else ElementList(new_seq)
+ setattr(self, attrname, new)
+
+ def _apply_syntax_extension_to_self(
+ self, extension: SyntaxExtension
+ ) -> None:
+ raise NotImplementedError()
+
+ def _get_syntax_extensions_as_dict(self) -> Mapping[_L, SyntaxExtension]:
+ res: Dict[_L, SyntaxExtension] = {}
+ for name, attr in self._position_map.items():
+ value = getattr(self, attr)
+ if value is not None:
+ res[name] = value
+ return res
+
+ def _set_syntax_extensions(self, **extensions: SyntaxExtension) -> None:
+ for name, value in extensions.items():
+ setattr(self, self._position_map[name], value) # type: ignore[index] # noqa: E501
+
+
+class SyntaxExtension(roles.SyntaxExtensionRole):
+ """Defines a unit that when also extending from :class:`.ClauseElement`
+ can be applied to SQLAlchemy statements :class:`.Select`,
+ :class:`_sql.Insert`, :class:`.Update` and :class:`.Delete` making use of
+ pre-established SQL insertion points within these constructs.
+
+ .. versionadded:: 2.1
+
+ .. seealso::
+
+ :ref:`examples_syntax_extensions`
+
+ """
+
+ def append_replacing_same_type(
+ self, existing: Sequence[ClauseElement]
+ ) -> Sequence[ClauseElement]:
+ """Utility function that can be used as
+ :paramref:`_sql.HasSyntaxExtensions.apply_extension_point.apply_fn`
+ to remove any other element of the same type in existing and appending
+ ``self`` to the list.
+
+ This is equivalent to::
+
+ stmt.apply_extension_point(
+ lambda existing: [
+ *(e for e in existing if not isinstance(e, ReplaceOfTypeExt)),
+ self,
+ ],
+ "post_criteria",
+ )
+
+ .. seealso::
+
+ :ref:`examples_syntax_extensions`
+
+ :meth:`_sql.HasSyntaxExtensions.apply_syntax_extension_point`
+
+ """ # noqa: E501
+ cls = type(self)
+ return [*(e for e in existing if not isinstance(e, cls)), self] # type: ignore[list-item] # noqa: E501
+
+ def apply_to_select(self, select_stmt: Select[Unpack[_Ts]]) -> None:
+ """Apply this :class:`.SyntaxExtension` to a :class:`.Select`"""
+ raise NotImplementedError(
+ f"Extension {type(self).__name__} cannot be applied to select"
+ )
+
+ def apply_to_update(self, update_stmt: Update) -> None:
+ """Apply this :class:`.SyntaxExtension` to an :class:`.Update`"""
+ raise NotImplementedError(
+ f"Extension {type(self).__name__} cannot be applied to update"
+ )
+
+ def apply_to_delete(self, delete_stmt: Delete) -> None:
+ """Apply this :class:`.SyntaxExtension` to a :class:`.Delete`"""
+ raise NotImplementedError(
+ f"Extension {type(self).__name__} cannot be applied to delete"
+ )
+
+ def apply_to_insert(self, insert_stmt: Insert) -> None:
+ """Apply this :class:`.SyntaxExtension` to an
+ :class:`_sql.Insert`"""
+ raise NotImplementedError(
+ f"Extension {type(self).__name__} cannot be applied to insert"
+ )
+
+
class Executable(roles.StatementRole):
"""Mark a :class:`_expression.ClauseElement` as supporting execution.
_execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
_is_default_generator = False
_with_options: Tuple[ExecutableOption, ...] = ()
- _with_context_options: Tuple[
+ _compile_state_funcs: Tuple[
Tuple[Callable[[CompileState], None], Any], ...
] = ()
_compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]]
_executable_traverse_internals = [
("_with_options", InternalTraversal.dp_executable_options),
(
- "_with_context_options",
- ExtendedInternalTraversal.dp_with_context_options,
+ "_compile_state_funcs",
+ ExtendedInternalTraversal.dp_compile_state_funcs,
),
("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
]
"""Apply options to this statement.
In the general sense, options are any kind of Python object
- that can be interpreted by the SQL compiler for the statement.
- These options can be consumed by specific dialects or specific kinds
- of compilers.
-
- The most commonly known kind of option are the ORM level options
- that apply "eager load" and other loading behaviors to an ORM
- query. However, options can theoretically be used for many other
- purposes.
+ that can be interpreted by systems that consume the statement outside
+ of the regular SQL compiler chain. Specifically, these options are
+ the ORM level options that apply "eager load" and other loading
+ behaviors to an ORM query.
For background on specific kinds of options for specific kinds of
statements, refer to the documentation for those option objects.
return self
@_generative
- def _add_context_option(
+ def _add_compile_state_func(
self,
callable_: Callable[[CompileState], None],
cache_args: Any,
) -> Self:
- """Add a context option to this statement.
+ """Add a compile state function to this statement.
- These are callable functions that will
+ When using the ORM only, these are callable functions that will
be given the CompileState object upon compilation.
A second argument cache_args is required, which will be combined with
cache key.
"""
- self._with_context_options += ((callable_, cache_args),)
+ self._compile_state_funcs += ((callable_, cache_args),)
return self
@overload
return repr((sql_str, param_tuple))
def __eq__(self, other: Any) -> bool:
- return bool(self.key == other.key)
+ return other is not None and bool(self.key == other.key)
def __ne__(self, other: Any) -> bool:
- return not (self.key == other.key)
+ return other is None or not (self.key == other.key)
@classmethod
def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
visit_propagate_attrs = PROPAGATE_ATTRS
- def visit_with_context_options(
+ def visit_compile_state_funcs(
self,
attrname: str,
obj: Any,
from ._typing import _DDLColumnArgument
from ._typing import _DMLTableArgument
from ._typing import _FromClauseArgument
+ from .base import SyntaxExtension
from .dml import _DMLTableElement
from .elements import BindParameter
from .elements import ClauseElement
) -> Union[ColumnElement[Any], TextClause]: ...
+@overload
+def expect(
+ role: Type[roles.SyntaxExtensionRole],
+ element: Any,
+ **kw: Any,
+) -> SyntaxExtension: ...
+
+
@overload
def expect(
role: Type[roles.LabeledColumnExprRole[Any]],
return _no_text_coercion(element, argname)
+class SyntaxExtensionImpl(RoleImpl):
+ __slots__ = ()
+
+
class StatementOptionImpl(_CoerceLiterals, RoleImpl):
__slots__ = ()
def visit_tuple(self, clauselist, **kw):
return "(%s)" % self.visit_clauselist(clauselist, **kw)
+ def visit_element_list(self, element, **kw):
+ return self._generate_delimited_list(element.clauses, " ", **kw)
+
def visit_clauselist(self, clauselist, **kw):
sep = clauselist.operator
if sep is None:
text = "SELECT " # we're off to a good start !
+ if select_stmt._post_select_clause is not None:
+ psc = self.process(select_stmt._post_select_clause, **kwargs)
+ if psc is not None:
+ text += psc + " "
+
if select_stmt._hints:
hint_text, byfrom = self._setup_select_hints(select_stmt)
if hint_text:
)
text += self.get_select_precolumns(select_stmt, **kwargs)
+
+ if select_stmt._pre_columns_clause is not None:
+ pcc = self.process(select_stmt._pre_columns_clause, **kwargs)
+ if pcc is not None:
+ text += pcc + " "
+
# the actual list of columns to print in the SELECT column list.
inner_columns = [
c
kwargs,
)
+ if select_stmt._post_body_clause is not None:
+ pbc = self.process(select_stmt._post_body_clause, **kwargs)
+ if pbc:
+ text += " " + pbc
+
if select_stmt._statement_hints:
per_dialect = [
ht
if t:
text += " \nHAVING " + t
+ if select._post_criteria_clause is not None:
+ pcc = self.process(select._post_criteria_clause, **kwargs)
+ if pcc is not None:
+ text += " \n" + pcc
+
if select._order_by_clauses:
text += self.order_by_clause(select, **kwargs)
):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
-
MySQL and MSSQL override this.
-
"""
raise NotImplementedError(
"This backend does not support multiple-table "
if limit_clause:
text += " " + limit_clause
+ if update_stmt._post_criteria_clause is not None:
+ ulc = self.process(
+ update_stmt._post_criteria_clause,
+ from_linter=from_linter,
+ **kw,
+ )
+
+ if ulc:
+ text += " " + ulc
+
if (
self.implicit_returning or update_stmt._returning
) and not self.returning_precedes_values:
if limit_clause:
text += " " + limit_clause
+ if delete_stmt._post_criteria_clause is not None:
+ dlc = self.process(
+ delete_stmt._post_criteria_clause,
+ from_linter=from_linter,
+ **kw,
+ )
+ if dlc:
+ text += " " + dlc
+
if (
self.implicit_returning or delete_stmt._returning
) and not self.returning_precedes_values:
from typing import Dict
from typing import Iterable
from typing import List
+from typing import Literal
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from .base import Executable
from .base import Generative
from .base import HasCompileState
+from .base import HasSyntaxExtensions
+from .base import SyntaxExtension
from .elements import BooleanClauseList
from .elements import ClauseElement
from .elements import ColumnClause
"""SELECT statement for INSERT .. FROM SELECT"""
_post_values_clause: Optional[ClauseElement] = None
- """used by extensions to Insert etc. to add additional syntacitcal
+ """used by extensions to Insert etc. to add additional syntactical
constructs, e.g. ON CONFLICT etc."""
_values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None
return self
-class Insert(ValuesBase):
+class Insert(ValuesBase, HasSyntaxExtensions[Literal["post_values"]]):
"""Represent an INSERT construct.
The :class:`_expression.Insert` object is created using the
:func:`_expression.insert()` function.
+ Available extension points:
+
+ * ``post_values``: applies additional logic after the ``VALUES`` clause.
+
"""
__visit_name__ = "insert"
+ HasCTE._has_ctes_traverse_internals
)
+ _position_map = util.immutabledict(
+ {
+ "post_values": "_post_values_clause",
+ }
+ )
+
+ _post_values_clause: Optional[ClauseElement] = None
+ """extension point for a ClauseElement that will be compiled directly
+ after the VALUES portion of the :class:`.Insert` statement
+
+ """
+
def __init__(self, table: _DMLTableArgument):
super().__init__(table)
+ def _apply_syntax_extension_to_self(
+ self, extension: SyntaxExtension
+ ) -> None:
+ extension.apply_to_insert(self)
+
@_generative
def inline(self) -> Self:
"""Make this :class:`_expression.Insert` construct "inline" .
"""
+# note: if not for MRO issues, this class should extend
+# from HasSyntaxExtensions[Literal["post_criteria"]]
class DMLWhereBase:
table: _DMLTableElement
_where_criteria: Tuple[ColumnElement[Any], ...] = ()
+ _post_criteria_clause: Optional[ClauseElement] = None
+ """used by extensions to Update/Delete etc. to add additional syntacitcal
+ constructs, e.g. LIMIT etc.
+
+ .. versionadded:: 2.1
+
+ """
+
+ # can't put position_map here either without HasSyntaxExtensions
+ # _position_map = util.immutabledict(
+ # {"post_criteria": "_post_criteria_clause"}
+ # )
+
@_generative
def where(self, *whereclause: _ColumnExpressionArgument[bool]) -> Self:
"""Return a new construct with the given expression(s) added to
)
-class Update(DMLWhereBase, ValuesBase):
+class Update(
+ DMLWhereBase, ValuesBase, HasSyntaxExtensions[Literal["post_criteria"]]
+):
"""Represent an Update construct.
The :class:`_expression.Update` object is created using the
:func:`_expression.update()` function.
+ Available extension points:
+
+ * ``post_criteria``: applies additional logic after the ``WHERE`` clause.
+
"""
__visit_name__ = "update"
("_returning", InternalTraversal.dp_clauseelement_tuple),
("_hints", InternalTraversal.dp_table_hint_list),
("_return_defaults", InternalTraversal.dp_boolean),
+ ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
(
"_return_defaults_columns",
InternalTraversal.dp_clauseelement_tuple,
+ HasCTE._has_ctes_traverse_internals
)
+ _position_map = util.immutabledict(
+ {"post_criteria": "_post_criteria_clause"}
+ )
+
def __init__(self, table: _DMLTableArgument):
super().__init__(table)
self._inline = True
return self
+ def _apply_syntax_extension_to_self(
+ self, extension: SyntaxExtension
+ ) -> None:
+ extension.apply_to_update(self)
+
if TYPE_CHECKING:
# START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8
"""
-class Delete(DMLWhereBase, UpdateBase):
+class Delete(
+ DMLWhereBase, UpdateBase, HasSyntaxExtensions[Literal["post_criteria"]]
+):
"""Represent a DELETE construct.
The :class:`_expression.Delete` object is created using the
:func:`_expression.delete()` function.
+ Available extension points:
+
+ * ``post_criteria``: applies additional logic after the ``WHERE`` clause.
+
"""
__visit_name__ = "delete"
("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
("_returning", InternalTraversal.dp_clauseelement_tuple),
("_hints", InternalTraversal.dp_table_hint_list),
+ ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
]
+ HasPrefixes._has_prefixes_traverse_internals
+ DialectKWArgs._dialect_kwargs_traverse_internals
+ HasCTE._has_ctes_traverse_internals
)
+ _position_map = util.immutabledict(
+ {"post_criteria": "_post_criteria_clause"}
+ )
+
def __init__(self, table: _DMLTableArgument):
self.table = coercions.expect(
roles.DMLTableRole, table, apply_propagate_attrs=self
)
+ def _apply_syntax_extension_to_self(
+ self, extension: SyntaxExtension
+ ) -> None:
+ extension.apply_to_delete(self)
+
if TYPE_CHECKING:
# START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
True_._create_singleton()
+class ElementList(DQLDMLClauseElement):
+ """Describe a list of clauses that will be space separated.
+
+ This is a minimal version of :class:`.ClauseList` which is used by
+ the :class:`.HasSyntaxExtension` class. It does not do any coercions
+ so should be used internally only.
+
+ .. versionadded:: 2.1
+
+ """
+
+ __visit_name__ = "element_list"
+
+ _traverse_internals: _TraverseInternalsType = [
+ ("clauses", InternalTraversal.dp_clauseelement_tuple),
+ ]
+
+ clauses: typing_Tuple[ClauseElement, ...]
+
+ def __init__(self, clauses: Sequence[ClauseElement]):
+ self.clauses = tuple(clauses)
+
+
class ClauseList(
roles.InElementRole,
roles.OrderByRole,
def __init__(self, element: ColumnElement[_T]):
self.element = element
+ self._propagate_attrs = element._propagate_attrs
@util.ro_non_memoized_property
def _from_objects(self) -> List[FromClause]:
def _order_by_label_element(self):
return self
+ def as_reference(self) -> _label_reference[_T]:
+ """refer to this labeled expression in a clause such as GROUP BY,
+ ORDER BY etc. as the label name itself, without expanding
+ into the full expression.
+
+ .. versionadded:: 2.1
+
+ """
+ return _label_reference(self)
+
@HasMemoized.memoized_attribute
def element(self) -> ColumnElement[_T]:
return self._element.self_group(against=operators.as_)
uses_inspection = False
+class SyntaxExtensionRole(SQLRole):
+ __slots__ = ()
+ _role_name = "Syntax extension construct"
+
+
class UsesInspection:
__slots__ = ()
_post_inspect: Literal[None] = None
from .base import Generative
from .base import HasCompileState
from .base import HasMemoized
+from .base import HasSyntaxExtensions
from .base import Immutable
+from .base import SyntaxExtension
from .coercions import _document_text_coercion
from .elements import _anonymous_label
from .elements import BindParameter
HasSuffixes,
HasHints,
HasCompileState,
+ HasSyntaxExtensions[
+ Literal["post_select", "pre_columns", "post_criteria", "post_body"]
+ ],
_SelectFromElements,
GenerativeSelect,
TypedReturnsRows[Unpack[_Ts]],
The :class:`_sql.Select` object is normally constructed using the
:func:`_sql.select` function. See that function for details.
+ Available extension points:
+
+ * ``post_select``: applies additional logic after the ``SELECT`` keyword.
+ * ``pre_columns``: applies additional logic between the ``DISTINCT``
+ keyword (if any) and the list of columns.
+ * ``post_criteria``: applies additional logic after the ``HAVING`` clause.
+ * ``post_body``: applies additional logic after the ``FOR UPDATE`` clause.
+
.. seealso::
:func:`_sql.select`
_where_criteria: Tuple[ColumnElement[Any], ...] = ()
_having_criteria: Tuple[ColumnElement[Any], ...] = ()
_from_obj: Tuple[FromClause, ...] = ()
+
+ _position_map = util.immutabledict(
+ {
+ "post_select": "_post_select_clause",
+ "pre_columns": "_pre_columns_clause",
+ "post_criteria": "_post_criteria_clause",
+ "post_body": "_post_body_clause",
+ }
+ )
+
+ _post_select_clause: Optional[ClauseElement] = None
+ """extension point for a ClauseElement that will be compiled directly
+ after the SELECT keyword.
+
+ .. versionadded:: 2.1
+
+ """
+
+ _pre_columns_clause: Optional[ClauseElement] = None
+ """extension point for a ClauseElement that will be compiled directly
+ before the "columns" clause; after DISTINCT (if present).
+
+ .. versionadded:: 2.1
+
+ """
+
+ _post_criteria_clause: Optional[ClauseElement] = None
+ """extension point for a ClauseElement that will be compiled directly
+ after "criteria", following the HAVING clause but before ORDER BY.
+
+ .. versionadded:: 2.1
+
+ """
+
+ _post_body_clause: Optional[ClauseElement] = None
+ """extension point for a ClauseElement that will be compiled directly
+ after the "body", following the ORDER BY, LIMIT, and FOR UPDATE sections
+ of the SELECT.
+
+ .. versionadded:: 2.1
+
+ """
+
_auto_correlate = True
_is_select_statement = True
_compile_options: CacheableOptions = (
("_distinct", InternalTraversal.dp_boolean),
("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
("_label_style", InternalTraversal.dp_plain_obj),
+ ("_post_select_clause", InternalTraversal.dp_clauseelement),
+ ("_pre_columns_clause", InternalTraversal.dp_clauseelement),
+ ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
+ ("_post_body_clause", InternalTraversal.dp_clauseelement),
]
+ HasCTE._has_ctes_traverse_internals
+ HasPrefixes._has_prefixes_traverse_internals
GenerativeSelect.__init__(self)
+ def _apply_syntax_extension_to_self(
+ self, extension: SyntaxExtension
+ ) -> None:
+ extension.apply_to_select(self)
+
def _scalar_type(self) -> TypeEngine[Any]:
if not self._raw_columns:
return NULLTYPE
for l, r in zip_longest(ltup, rtup, fillvalue=None):
self.stack.append((l, r))
+ def visit_multi_list(
+ self, attrname, left_parent, left, right_parent, right, **kw
+ ):
+ for l, r in zip_longest(left, right, fillvalue=None):
+ if isinstance(l, str):
+ if not isinstance(r, str) or l != r:
+ return COMPARE_FAILED
+ elif isinstance(r, str):
+ if not isinstance(l, str) or l != r:
+ return COMPARE_FAILED
+ else:
+ self.stack.append((l, r))
+
def visit_clauseelement_list(
self, attrname, left_parent, left, right_parent, right, **kw
):
else:
return left == right
- def visit_with_context_options(
+ def visit_compile_state_funcs(
self, attrname, left_parent, left, right_parent, right, **kw
):
return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
dp_executable_options = "EO"
- dp_with_context_options = "WC"
+ dp_compile_state_funcs = "WC"
dp_fromclause_ordered_set = "CO"
"""Visit an ordered set of :class:`_expression.FromClause` objects. """
stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
)
from .sql import CacheKeyFixture as CacheKeyFixture
+from .sql import CacheKeySuite as CacheKeySuite
from .sql import (
ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
)
class CacheKeyFixture:
- def _compare_equal(self, a, b, compare_values):
+ def _compare_equal(self, a, b, *, compare_values=False):
a_key = a._generate_cache_key()
b_key = b._generate_cache_key()
if a_key is None:
- assert a._annotations.get("nocache")
+ assert a._annotations.get("nocache"), (
+ "Construct doesn't cache, so test suite should "
+ "add the 'nocache' annotation"
+ )
assert b_key is None
else:
assert a_param.compare(b_param, compare_values=compare_values)
return a_key, b_key
- def _run_cache_key_fixture(self, fixture, compare_values):
+ def _run_compare_fixture(self, fixture, *, compare_values=False):
+ case_a = fixture()
+ case_b = fixture()
+
+ for a, b in itertools.combinations_with_replacement(
+ range(len(case_a)), 2
+ ):
+ if a == b:
+ assert case_a[a].compare(
+ case_b[b], compare_values=compare_values
+ )
+ else:
+ assert not case_a[a].compare(
+ case_b[b], compare_values=compare_values
+ )
+
+ def _run_cache_key_fixture(self, fixture, *, compare_values=False):
case_a = fixture()
case_b = fixture()
):
if a == b:
a_key, b_key = self._compare_equal(
- case_a[a], case_b[b], compare_values
+ case_a[a], case_b[b], compare_values=compare_values
)
if a_key is None:
continue
for a, b in itertools.combinations_with_replacement(
range(len(case_a)), 2
):
- self._compare_equal(case_a[a], case_b[b], compare_values)
+ self._compare_equal(
+ case_a[a], case_b[b], compare_values=compare_values
+ )
+
+
+class CacheKeySuite(CacheKeyFixture):
+ @classmethod
+ def run_suite_tests(cls, fn):
+ def decorate(self):
+ self._run_cache_key_fixture(fn(self), compare_values=False)
+ self._run_compare_fixture(fn(self), compare_values=False)
+
+ decorate.__name__ = fn.__name__
+ return decorate
def insertmanyvalues_fixture(
fixtures.TestBase,
):
pass
+
+
+test_qualify = __import__(
+ "examples.syntax_extensions.test_qualify"
+).syntax_extensions.test_qualify
+
+
+class QualifyCompileTest(test_qualify.QualifyCompileTest, fixtures.TestBase):
+ pass
--- /dev/null
+from __future__ import annotations
+
+from typing import Any
+
+from sqlalchemy import insert
+from sqlalchemy import inspect
+from sqlalchemy import select
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
+from sqlalchemy.sql._typing import _ColumnExpressionArgument
+from sqlalchemy.sql.base import SyntaxExtension
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.dml import Update
+from sqlalchemy.sql.visitors import _TraverseInternalsType
+from sqlalchemy.sql.visitors import InternalTraversal
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from .test_query import QueryTest
+
+
+class PostSelectClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_select",
+ )
+
+
+class PreColumnsClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "pre_columns",
+ )
+
+
+class PostCriteriaClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_criteria",
+ )
+
+ def apply_to_update(self, update_stmt: Update) -> None:
+ update_stmt.apply_syntax_extension_point(
+ lambda existing: [self], "post_criteria"
+ )
+
+ def apply_to_delete(self, delete_stmt: Delete) -> None:
+ delete_stmt.apply_syntax_extension_point(
+ lambda existing: [self], "post_criteria"
+ )
+
+
+class PostCriteriaClause2(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_criteria",
+ )
+
+
+class PostCriteriaClauseCols(PostCriteriaClause):
+ _traverse_internals: _TraverseInternalsType = [
+ ("exprs", InternalTraversal.dp_clauseelement_tuple),
+ ]
+
+ def __init__(self, *exprs: _ColumnExpressionArgument[Any]):
+ self.exprs = tuple(
+ coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+ for e in exprs
+ )
+
+
+class PostCriteriaClauseColsNoProp(PostCriteriaClause):
+ _traverse_internals: _TraverseInternalsType = [
+ ("exprs", InternalTraversal.dp_clauseelement_tuple),
+ ]
+
+ def __init__(self, *exprs: _ColumnExpressionArgument[Any]):
+ self.exprs = tuple(coercions.expect(roles.ByOfRole, e) for e in exprs)
+
+
+class PostBodyClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "post_body",
+ )
+
+
+class PostValuesClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_insert(self, insert_stmt):
+ insert_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "post_values",
+ )
+
+
+@compiles(PostSelectClause)
+def _compile_psk(element, compiler, **kw):
+ return "POST SELECT KEYWORD"
+
+
+@compiles(PreColumnsClause)
+def _compile_pcc(element, compiler, **kw):
+ return "PRE COLUMNS"
+
+
+@compiles(PostCriteriaClause)
+def _compile_psc(element, compiler, **kw):
+ return "POST CRITERIA"
+
+
+@compiles(PostCriteriaClause2)
+def _compile_psc2(element, compiler, **kw):
+ return "2 POST CRITERIA 2"
+
+
+@compiles(PostCriteriaClauseCols)
+def _compile_psc_cols(element, compiler, **kw):
+ return f"""PC COLS ({
+ ', '.join(compiler.process(expr, **kw) for expr in element.exprs)
+ })"""
+
+
+@compiles(PostBodyClause)
+def _compile_psb(element, compiler, **kw):
+ return "POST SELECT BODY"
+
+
+@compiles(PostValuesClause)
+def _compile_pvc(element, compiler, **kw):
+ return "POST VALUES"
+
+
+class TestExtensionPoints(QueryTest, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ def test_select_post_select_clause(self):
+ User = self.classes.User
+
+ stmt = select(User).ext(PostSelectClause()).where(User.name == "x")
+ self.assert_compile(
+ stmt,
+ "SELECT POST SELECT KEYWORD users.id, users.name "
+ "FROM users WHERE users.name = :name_1",
+ )
+
+ def test_select_pre_columns_clause(self):
+ User = self.classes.User
+
+ stmt = select(User).ext(PreColumnsClause()).where(User.name == "x")
+ self.assert_compile(
+ stmt,
+ "SELECT PRE COLUMNS users.id, users.name FROM users "
+ "WHERE users.name = :name_1",
+ )
+
+ def test_select_post_criteria_clause(self):
+ User = self.classes.User
+
+ stmt = select(User).ext(PostCriteriaClause()).where(User.name == "x")
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users "
+ "WHERE users.name = :name_1 POST CRITERIA",
+ )
+
+ def test_select_post_criteria_clause_multiple(self):
+ User = self.classes.User
+
+ stmt = (
+ select(User)
+ .ext(PostCriteriaClause())
+ .ext(PostCriteriaClause2())
+ .where(User.name == "x")
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users "
+ "WHERE users.name = :name_1 POST CRITERIA 2 POST CRITERIA 2",
+ )
+
+ def test_select_post_select_body(self):
+ User = self.classes.User
+
+ stmt = select(User).ext(PostBodyClause()).where(User.name == "x")
+
+ self.assert_compile(
+ stmt,
+ "SELECT users.id, users.name FROM users "
+ "WHERE users.name = :name_1 POST SELECT BODY",
+ )
+
+ def test_insert_post_values(self):
+ User = self.classes.User
+
+ self.assert_compile(
+ insert(User).ext(PostValuesClause()),
+ "INSERT INTO users (id, name) VALUES (:id, :name) POST VALUES",
+ )
+
+ def test_update_post_criteria(self):
+ User = self.classes.User
+
+ self.assert_compile(
+ update(User).ext(PostCriteriaClause()).where(User.name == "hi"),
+ "UPDATE users SET id=:id, name=:name "
+ "WHERE users.name = :name_1 POST CRITERIA",
+ )
+
+ @testing.combinations(
+ (lambda User: select(1).ext(PostCriteriaClauseCols(User.id)), True),
+ (
+ lambda User: select(1).ext(PostCriteriaClauseColsNoProp(User.id)),
+ False,
+ ),
+ (
+ lambda User, users: users.update().ext(
+ PostCriteriaClauseCols(User.id)
+ ),
+ True,
+ ),
+ (
+ lambda User, users: users.delete().ext(
+ PostCriteriaClauseCols(User.id)
+ ),
+ True,
+ ),
+ (lambda User, users: users.delete(), False),
+ )
+ def test_propagate_attrs(self, stmt, expected):
+ User = self.classes.User
+ user_table = self.tables.users
+
+ stmt = testing.resolve_lambda(stmt, User=User, users=user_table)
+
+ if expected:
+ eq_(
+ stmt._propagate_attrs,
+ {
+ "compile_state_plugin": "orm",
+ "plugin_subject": inspect(User),
+ },
+ )
+ else:
+ eq_(stmt._propagate_attrs, {})
from sqlalchemy import cast
from sqlalchemy import Column
from sqlalchemy import column
+from sqlalchemy import DateTime
from sqlalchemy import dialects
from sqlalchemy import exists
from sqlalchemy import extract
from sqlalchemy.sql.annotation import Annotated
from sqlalchemy.sql.base import HasCacheKey
from sqlalchemy.sql.base import SingletonConstant
+from sqlalchemy.sql.base import SyntaxExtension
from sqlalchemy.sql.elements import _label_reference
from sqlalchemy.sql.elements import _textual_label_reference
from sqlalchemy.sql.elements import BindParameter
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ClauseList
from sqlalchemy.sql.elements import CollationClause
+from sqlalchemy.sql.elements import DQLDMLClauseElement
+from sqlalchemy.sql.elements import ElementList
from sqlalchemy.sql.elements import Immutable
from sqlalchemy.sql.elements import Null
from sqlalchemy.sql.elements import Slice
+from sqlalchemy.sql.elements import TypeClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.sql.functions import GenericFunction
_label_reference(table_a.c.a.desc()),
_label_reference(table_a.c.a.asc()),
),
+ lambda: (
+ TypeClause(String(50)),
+ TypeClause(DateTime()),
+ ),
+ lambda: (
+ table_a.c.a,
+ ElementList([table_a.c.a]),
+ ElementList([table_a.c.a, table_a.c.b]),
+ ),
lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
lambda: (
text("select a, b from table").columns(a=Integer, b=String),
def _statements_w_context_options_fixtures():
return [
- select(table_a)._add_context_option(opt1, True),
- select(table_a)._add_context_option(opt1, 5),
+ select(table_a)._add_compile_state_func(opt1, True),
+ select(table_a)._add_compile_state_func(opt1, 5),
select(table_a)
- ._add_context_option(opt1, True)
- ._add_context_option(opt2, True),
+ ._add_compile_state_func(opt1, True)
+ ._add_compile_state_func(opt2, True),
select(table_a)
- ._add_context_option(opt1, True)
- ._add_context_option(opt2, 5),
- select(table_a)._add_context_option(opt3, True),
+ ._add_compile_state_func(opt1, True)
+ ._add_compile_state_func(opt2, 5),
+ select(table_a)._add_compile_state_func(opt3, True),
]
fixtures.append(_statements_w_context_options_fixtures)
# a typed column expression, so this is fine
return (column("x", Integer).in_(elements),)
- self._run_cache_key_fixture(fixture, False)
+ self._run_cache_key_fixture(fixture, compare_values=False)
def test_cache_key(self):
for fixtures_, compare_values in [
(self.type_cache_key_fixtures, False),
]:
for fixture in fixtures_:
- self._run_cache_key_fixture(fixture, compare_values)
+ self._run_cache_key_fixture(
+ fixture, compare_values=compare_values
+ )
def test_cache_key_equal(self):
for fixture in self.equal_fixtures:
self._run_cache_key_fixture(
fixture,
- True,
+ compare_values=True,
)
def test_bindparam_subclass_nocache(self):
_literal_bindparam(None),
)
- self._run_cache_key_fixture(fixture, True)
+ self._run_cache_key_fixture(fixture, compare_values=True)
def test_cache_key_unknown_traverse(self):
class Foobar1(ClauseElement):
),
"FromStatement": (
{"_raw_columns", "_with_options", "element"}
- | {"_propagate_attrs", "_with_context_options"},
+ | {"_propagate_attrs", "_compile_state_funcs"},
{"element", "entities"},
),
"FunctionAsBinary": (
"_hints",
"_independent_ctes",
"_distinct_on",
- "_with_context_options",
+ "_compile_state_funcs",
"_setup_joins",
"_suffixes",
"_memoized_select_entities",
"_annotations",
"_fetch_clause_options",
"_from_obj",
+ "_post_select_clause",
+ "_post_body_clause",
+ "_post_criteria_clause",
+ "_pre_columns_clause",
},
{"entities"},
),
@testing.combinations(
*all_hascachekey_subclasses(
- ignore_subclasses=[Annotated, NoInit, SingletonConstant]
+ ignore_subclasses=[
+ Annotated,
+ NoInit,
+ SingletonConstant,
+ SyntaxExtension,
+ ]
)
)
def test_init_args_in_traversal(self, cls: type):
if "orm" not in cls.__module__
and "compiler" not in cls.__module__
and "dialects" not in cls.__module__
- and issubclass(cls, (ColumnElement, Selectable, LambdaElement))
+ and issubclass(
+ cls,
+ (
+ ColumnElement,
+ Selectable,
+ LambdaElement,
+ DQLDMLClauseElement,
+ ),
+ )
)
for fixture in self.fixtures + self.dont_compare_values_fixtures:
--- /dev/null
+from sqlalchemy import Column
+from sqlalchemy import column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy import table
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.base import SyntaxExtension
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.dml import Update
+from sqlalchemy.sql.visitors import _TraverseInternalsType
+from sqlalchemy.sql.visitors import InternalTraversal
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+
+
+class PostSelectClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_select",
+ )
+
+
+class PreColumnsClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "pre_columns",
+ )
+
+
+class PostCriteriaClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_criteria",
+ )
+
+ def apply_to_update(self, update_stmt: Update) -> None:
+ update_stmt.apply_syntax_extension_point(
+ lambda existing: [self], "post_criteria"
+ )
+
+ def apply_to_delete(self, delete_stmt: Delete) -> None:
+ delete_stmt.apply_syntax_extension_point(
+ lambda existing: [self], "post_criteria"
+ )
+
+
+class PostCriteriaClause2(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ self.append_replacing_same_type,
+ "post_criteria",
+ )
+
+
+class PostCriteriaClause3(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "post_criteria",
+ )
+
+
+class PostBodyClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "post_body",
+ )
+
+
+class PostValuesClause(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_insert(self, insert_stmt):
+ insert_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "post_values",
+ )
+
+
+class ColumnExpressionExt(SyntaxExtension, ClauseElement):
+ _traverse_internals: _TraverseInternalsType = [
+ ("_exprs", InternalTraversal.dp_clauseelement_tuple),
+ ]
+
+ def __init__(self, *exprs):
+ self._exprs = tuple(
+ coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+ for e in exprs
+ )
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [*existing, self],
+ "post_select",
+ )
+
+
+@compiles(PostSelectClause)
+def _compile_psk(element, compiler, **kw):
+ return "POST SELECT KEYWORD"
+
+
+@compiles(PreColumnsClause)
+def _compile_pcc(element, compiler, **kw):
+ return "PRE COLUMNS"
+
+
+@compiles(PostCriteriaClause)
+def _compile_psc(element, compiler, **kw):
+ return "POST CRITERIA"
+
+
+@compiles(PostCriteriaClause2)
+def _compile_psc2(element, compiler, **kw):
+ return "2 POST CRITERIA 2"
+
+
+@compiles(PostCriteriaClause3)
+def _compile_psc3(element, compiler, **kw):
+ return "3 POST CRITERIA 3"
+
+
+@compiles(PostBodyClause)
+def _compile_psb(element, compiler, **kw):
+ return "POST SELECT BODY"
+
+
+@compiles(PostValuesClause)
+def _compile_pvc(element, compiler, **kw):
+ return "POST VALUES"
+
+
+@compiles(ColumnExpressionExt)
+def _compile_cee(element, compiler, **kw):
+ inner = ", ".join(compiler.process(elem, **kw) for elem in element._exprs)
+ return f"COLUMN EXPRESSIONS ({inner})"
+
+
+class TestExtensionPoints(fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ def test_illegal_section(self):
+ class SomeExtension(SyntaxExtension, ClauseElement):
+ _traverse_internals = []
+
+ def apply_to_select(self, select_stmt):
+ select_stmt.apply_syntax_extension_point(
+ lambda existing: [self],
+ "not_present",
+ )
+
+ with expect_raises_message(
+ ValueError,
+ r"Unknown position 'not_present' for <class .*Select'> "
+ "construct; known positions: "
+ "'post_select', 'pre_columns', 'post_criteria', 'post_body'",
+ ):
+ select(column("q")).ext(SomeExtension())
+
+ def test_select_post_select_clause(self):
+ self.assert_compile(
+ select(column("a"), column("b"))
+ .ext(PostSelectClause())
+ .where(column("q") == 5),
+ "SELECT POST SELECT KEYWORD a, b WHERE q = :q_1",
+ )
+
+ def test_select_pre_columns_clause(self):
+ self.assert_compile(
+ select(column("a"), column("b"))
+ .ext(PreColumnsClause())
+ .where(column("q") == 5)
+ .distinct(),
+ "SELECT DISTINCT PRE COLUMNS a, b WHERE q = :q_1",
+ )
+
+ def test_select_post_criteria_clause(self):
+ self.assert_compile(
+ select(column("a"), column("b"))
+ .ext(PostCriteriaClause())
+ .where(column("q") == 5)
+ .having(column("z") == 10)
+ .order_by(column("r")),
+ "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+ "POST CRITERIA ORDER BY r",
+ )
+
+ def test_select_post_criteria_clause_multiple(self):
+ self.assert_compile(
+ select(column("a"), column("b"))
+ .ext(PostCriteriaClause())
+ .ext(PostCriteriaClause2())
+ .where(column("q") == 5)
+ .having(column("z") == 10)
+ .order_by(column("r")),
+ "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+ "POST CRITERIA 2 POST CRITERIA 2 ORDER BY r",
+ )
+
+ def test_select_post_criteria_clause_multiple2(self):
+ stmt = (
+ select(column("a"), column("b"))
+ .ext(PostCriteriaClause())
+ .ext(PostCriteriaClause())
+ .ext(PostCriteriaClause2())
+ .ext(PostCriteriaClause2())
+ .where(column("q") == 5)
+ .having(column("z") == 10)
+ .order_by(column("r"))
+ )
+ # PostCriteriaClause2 is here only once
+ self.assert_compile(
+ stmt,
+ "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+ "POST CRITERIA POST CRITERIA 2 POST CRITERIA 2 ORDER BY r",
+ )
+ # now there is only PostCriteriaClause3
+ self.assert_compile(
+ stmt.ext(PostCriteriaClause3()),
+ "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+ "3 POST CRITERIA 3 ORDER BY r",
+ )
+
+ def test_select_post_select_body(self):
+ self.assert_compile(
+ select(column("a"), column("b"))
+ .ext(PostBodyClause())
+ .where(column("q") == 5)
+ .having(column("z") == 10)
+ .order_by(column("r"))
+ .limit(15),
+ "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+ "ORDER BY r LIMIT :param_1 POST SELECT BODY",
+ )
+
+ def test_insert_post_values(self):
+ t = table("t", column("a"), column("b"))
+ self.assert_compile(
+ t.insert().ext(PostValuesClause()),
+ "INSERT INTO t (a, b) VALUES (:a, :b) POST VALUES",
+ )
+
+ def test_update_post_criteria(self):
+ t = table("t", column("a"), column("b"))
+ self.assert_compile(
+ t.update().ext(PostCriteriaClause()).where(t.c.a == "hi"),
+ "UPDATE t SET a=:a, b=:b WHERE t.a = :a_1 POST CRITERIA",
+ )
+
+ def test_delete_post_criteria(self):
+ t = table("t", column("a"), column("b"))
+ self.assert_compile(
+ t.delete().ext(PostCriteriaClause()).where(t.c.a == "hi"),
+ "DELETE FROM t WHERE t.a = :a_1 POST CRITERIA",
+ )
+
+
+class TestExpressionExtensions(
+ fixtures.CacheKeyFixture, fixtures.TestBase, AssertsCompiledSQL
+):
+ __dialect__ = "default"
+
+ def test_render(self):
+ t = Table(
+ "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+ )
+
+ stmt = select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2))
+ self.assert_compile(
+ stmt,
+ "SELECT COLUMN EXPRESSIONS (t1.c1, t1.c2) t1.c1, t1.c2 FROM t1",
+ )
+
+ def test_adaptation(self):
+ t = Table(
+ "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+ )
+
+ s1 = select(t).subquery()
+ s2 = select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2))
+ s3 = sql_util.ClauseAdapter(s1).traverse(s2)
+
+ self.assert_compile(
+ s3,
+ "SELECT COLUMN EXPRESSIONS (anon_1.c1, anon_1.c2) "
+ "anon_1.c1, anon_1.c2 FROM "
+ "(SELECT t1.c1 AS c1, t1.c2 AS c2 FROM t1) AS anon_1",
+ )
+
+ def test_compare(self):
+ t = Table(
+ "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+ )
+
+ self._run_compare_fixture(
+ lambda: (
+ select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2)),
+ select(t).ext(ColumnExpressionExt(t.c.c1)),
+ select(t),
+ )
+ )