]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
syntax extensions (patch 1)
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 16 Dec 2024 22:29:22 +0000 (17:29 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 5 Mar 2025 21:03:16 +0000 (16:03 -0500)
Added the ability to create custom SQL constructs that can define new
clauses within SELECT, INSERT, UPDATE, and DELETE statements without
needing to modify the construction or compilation code of of
:class:`.Select`, :class:`.Insert`, :class:`.Update`, or :class:`.Delete`
directly.  Support for testing these constructs, including caching support,
is present along with an example test suite.  The use case for these
constructs is expected to be third party dialects for NewSQL or other novel
styles of database that introduce new clauses to these statements.   A new
example suite is included which illustrates the ``QUALIFY`` SQL construct
used by several NewSQL databases which includes a cachable implementation
as well as a test suite.

Since these extensions start to make it a bit crowded with how many
kinds of "options" we have on statements, did some naming /
documentation changes with existing constructs on Executable, in
particular to distinguish ExecutableOption from SyntaxExtension.

Fixes: #12195
Change-Id: I4a44ee5bbc3d8b1b640837680c09d25b1b7077af

26 files changed:
doc/build/changelog/unreleased_21/12195.rst [new file with mode: 0644]
doc/build/core/compiler.rst
doc/build/orm/examples.rst
examples/syntax_extensions/__init__.py [new file with mode: 0644]
examples/syntax_extensions/qualify.py [new file with mode: 0644]
examples/syntax_extensions/test_qualify.py [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/cache_key.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/testing/fixtures/__init__.py
lib/sqlalchemy/testing/fixtures/sql.py
test/base/test_examples.py
test/orm/test_syntax_extensions.py [new file with mode: 0644]
test/sql/test_compare.py
test/sql/test_syntax_extensions.py [new file with mode: 0644]

diff --git a/doc/build/changelog/unreleased_21/12195.rst b/doc/build/changelog/unreleased_21/12195.rst
new file mode 100644 (file)
index 0000000..a36d1bc
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: feature, sql
+    :tickets: 12195
+
+    Added the ability to create custom SQL constructs that can define new
+    clauses within SELECT, INSERT, UPDATE, and DELETE statements without
+    needing to modify the construction or compilation code of of
+    :class:`.Select`, :class:`.Insert`, :class:`.Update`, or :class:`.Delete`
+    directly.  Support for testing these constructs, including caching support,
+    is present along with an example test suite.  The use case for these
+    constructs is expected to be third party dialects for analytical SQL
+    (so-called NewSQL) or other novel styles of database that introduce new
+    clauses to these statements.   A new example suite is included which
+    illustrates the ``QUALIFY`` SQL construct used by several NewSQL databases
+    which includes a cachable implementation as well as a test suite.
+
+    .. seealso::
+
+        :ref:`examples.syntax_extensions`
+
index 202ef2b0ec097f5c47d386532b32f809f39170ef..ff1f95399825656f68cdc8f63e3d9f1a3e31dc2f 100644 (file)
@@ -5,3 +5,7 @@ Custom SQL Constructs and Compilation Extension
 
 .. automodule:: sqlalchemy.ext.compiler
     :members:
+
+
+.. autoclass:: sqlalchemy.sql.SyntaxExtension
+    :members:
index 9e38768b329883610a3f3384a3dfa9ae6fbd6a3f..8a4dd86e38dc0cce2ba3e1c42b1440f4fcaf25a5 100644 (file)
@@ -1,8 +1,8 @@
 .. _examples_toplevel:
 
-============
-ORM Examples
-============
+=====================
+Core and ORM Examples
+=====================
 
 The SQLAlchemy distribution includes a variety of code examples illustrating
 a select set of patterns, some typical and some not so typical.   All are
@@ -135,6 +135,16 @@ Horizontal Sharding
 
 .. automodule:: examples.sharding
 
+Extending Core
+==============
+
+.. _examples_syntax_extensions:
+
+Extending Statements like SELECT, INSERT, etc
+----------------------------------------------
+
+.. automodule:: examples.syntax_extensions
+
 Extending the ORM
 =================
 
diff --git a/examples/syntax_extensions/__init__.py b/examples/syntax_extensions/__init__.py
new file mode 100644 (file)
index 0000000..aa3c6b5
--- /dev/null
@@ -0,0 +1,10 @@
+"""
+A detailed example of extending the :class:`.Select` construct to include
+a new non-SQL standard clause ``QUALIFY``.
+
+This example illustrates both the :ref:`sqlalchemy.ext.compiler_toplevel`
+as well as an extension known as :class:`.SyntaxExtension`.
+
+.. autosource::
+
+"""
diff --git a/examples/syntax_extensions/qualify.py b/examples/syntax_extensions/qualify.py
new file mode 100644 (file)
index 0000000..7ab02b3
--- /dev/null
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql import ColumnExpressionArgument
+from sqlalchemy.sql import roles
+from sqlalchemy.sql import Select
+from sqlalchemy.sql import SyntaxExtension
+from sqlalchemy.sql import visitors
+
+
+def qualify(predicate: ColumnExpressionArgument[bool]) -> Qualify:
+    """Return a QUALIFY construct
+
+    E.g.::
+
+        stmt = select(qt_table).ext(
+            qualify(func.row_number().over(order_by=qt_table.c.o))
+        )
+
+    """
+    return Qualify(predicate)
+
+
+class Qualify(SyntaxExtension, ClauseElement):
+    """Define the QUALIFY class."""
+
+    predicate: ColumnElement[bool]
+    """A single column expression that is the predicate within the QUALIFY."""
+
+    _traverse_internals = [
+        ("predicate", visitors.InternalTraversal.dp_clauseelement)
+    ]
+    """This structure defines how SQLAlchemy can do a deep traverse of internal
+    contents of this structure.  This is mostly used for cache key generation.
+    If the traversal is not written yet, the ``inherit_cache=False`` class
+    level attribute may be used to skip caching for the construct.
+    """
+
+    def __init__(self, predicate: ColumnExpressionArgument):
+        self.predicate = coercions.expect(
+            roles.WhereHavingRole, predicate, apply_propagate_attrs=self
+        )
+
+    def apply_to_select(self, select_stmt: Select) -> None:
+        """Called when the :meth:`.Select.ext` method is called.
+
+        The extension should apply itself to the :class:`.Select`, typically
+        using :meth:`.HasStatementExtensions.apply_syntax_extension_point`,
+        which receives a callable that receives a list of current elements to
+        be concatenated together and then returns a new list of elements to be
+        concatenated together in the final structure.  The
+        :meth:`.SyntaxExtension.append_replacing_same_type` callable is
+        usually used for this.
+
+        """
+        select_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type, "post_criteria"
+        )
+
+
+@compiles(Qualify)
+def _compile_qualify(element, compiler, **kw):
+    """a compiles extension that delivers the SQL text for Qualify"""
+    return f"QUALIFY {compiler.process(element.predicate, **kw)}"
diff --git a/examples/syntax_extensions/test_qualify.py b/examples/syntax_extensions/test_qualify.py
new file mode 100644 (file)
index 0000000..94c90bd
--- /dev/null
@@ -0,0 +1,170 @@
+import random
+import unittest
+
+from sqlalchemy import Column
+from sqlalchemy import func
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from .qualify import qualify
+
+qt_table = Table(
+    "qt",
+    MetaData(),
+    Column("i", Integer),
+    Column("p", Integer),
+    Column("o", Integer),
+)
+
+
+class QualifyCompileTest(AssertsCompiledSQL, fixtures.CacheKeySuite):
+    """A sample test suite for the QUALIFY clause, making use of SQLAlchemy
+    testing utilities.
+
+    """
+
+    __dialect__ = "default"
+
+    @fixtures.CacheKeySuite.run_suite_tests
+    def test_qualify_cache_key(self):
+        """A cache key suite using the ``CacheKeySuite.run_suite_tests``
+        decorator.
+
+        This suite intends to test that the "_traverse_internals" structure
+        of the custom SQL construct covers all the structural elements of
+        the object.    A decorated function should return a callable (e.g.
+        a lambda) which returns a list of SQL structures.    The suite will
+        call upon this lambda multiple times, to make the same list of
+        SQL structures repeatedly.  It then runs comparisons of the generated
+        cache key for each element in a particular list to all the other
+        elements in that same list, as well as other versions of the list.
+
+        The rules for this list are then as follows:
+
+        * Each element of the list should store a SQL structure that is
+          **structurally identical** each time, for a given position in the
+          list.   Successive versions of this SQL structure will be compared
+          to previous ones in the same list position and they must be
+          identical.
+
+        * Each element of the list should store a SQL structure that is
+          **structurally different** from **all other** elements in the list.
+          Successive versions of this SQL structure will be compared to
+          other members in other list positions, and they must be different
+          each time.
+
+        * The SQL structures returned in the list should exercise all of the
+          structural features that are provided by the construct.  This is
+          to ensure that two different structural elements generate a
+          different cache key and won't be mis-cached.
+
+        * Literal parameters like strings and numbers are **not** part of the
+          cache key itself since these are not "structural" elements; two
+          SQL structures that are identical can nonetheless have different
+          parameterized values.  To better exercise testing that this variation
+          is not stored as part of the cache key, ``random`` functions like
+          ``random.randint()`` or ``random.choice()`` can be used to generate
+          random literal values within a single element.
+
+
+        """
+
+        def stmt0():
+            return select(qt_table)
+
+        def stmt1():
+            stmt = stmt0()
+
+            return stmt.ext(qualify(qt_table.c.p == random.choice([2, 6, 10])))
+
+        def stmt2():
+            stmt = stmt0()
+
+            return stmt.ext(
+                qualify(func.row_number().over(order_by=qt_table.c.o))
+            )
+
+        def stmt3():
+            stmt = stmt0()
+
+            return stmt.ext(
+                qualify(
+                    func.row_number().over(
+                        partition_by=qt_table.c.i, order_by=qt_table.c.o
+                    )
+                )
+            )
+
+        return lambda: [stmt0(), stmt1(), stmt2(), stmt3()]
+
+    def test_query_one(self):
+        """A compilation test.  This makes use of the
+        ``AssertsCompiledSQL.assert_compile()`` utility.
+
+        """
+
+        stmt = select(qt_table).ext(
+            qualify(
+                func.row_number().over(
+                    partition_by=qt_table.c.p, order_by=qt_table.c.o
+                )
+                == 1
+            )
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT qt.i, qt.p, qt.o FROM qt QUALIFY row_number() "
+            "OVER (PARTITION BY qt.p ORDER BY qt.o) = :param_1",
+        )
+
+    def test_query_two(self):
+        """A compilation test.  This makes use of the
+        ``AssertsCompiledSQL.assert_compile()`` utility.
+
+        """
+
+        row_num = (
+            func.row_number()
+            .over(partition_by=qt_table.c.p, order_by=qt_table.c.o)
+            .label("row_num")
+        )
+        stmt = select(qt_table, row_num).ext(
+            qualify(row_num.as_reference() == 1)
+        )
+
+        self.assert_compile(
+            stmt,
+            "SELECT qt.i, qt.p, qt.o, row_number() OVER "
+            "(PARTITION BY qt.p ORDER BY qt.o) AS row_num "
+            "FROM qt QUALIFY row_num = :param_1",
+        )
+
+    def test_propagate_attrs(self):
+        """ORM propagate test.  this is an optional test that tests
+        apply_propagate_attrs, indicating when you pass ORM classes /
+        attributes to your construct, there's a dictionary called
+        ``._propagate_attrs`` that gets carried along to the statement,
+        which marks it as an "ORM" statement.
+
+        """
+        row_num = (
+            func.row_number().over(partition_by=qt_table.c.p).label("row_num")
+        )
+        row_num._propagate_attrs = {"foo": "bar"}
+
+        stmt = select(1).ext(qualify(row_num.as_reference() == 1))
+
+        eq_(stmt._propagate_attrs, {"foo": "bar"})
+
+
+class QualifyCompileUnittest(QualifyCompileTest, unittest.TestCase):
+    pass
+
+
+if __name__ == "__main__":
+    unittest.main()
index 158a81712b6ba84a0c568e7381a03d376840585c..fef29bd50e96638eb5992e88863694c2fde26391 100644 (file)
@@ -651,6 +651,10 @@ class _ORMCompileState(_AbstractORMCompileState):
         passed to with_polymorphic (which is completely unnecessary in modern
         use).
 
+        TODO: What is a "quasi-legacy" case?   Do we need this method with
+        2.0 style select() queries or not?   Why is with_polymorphic referring
+        to an alias or subquery "legacy" ?
+
         """
         if (
             not ext_info.is_aliased_class
@@ -862,8 +866,8 @@ class _ORMFromStatementCompileState(_ORMCompileState):
                 if opt._is_compile_state:
                     opt.process_compile_state(self)
 
-        if statement_container._with_context_options:
-            for fn, key in statement_container._with_context_options:
+        if statement_container._compile_state_funcs:
+            for fn, key in statement_container._compile_state_funcs:
                 fn(self)
 
         self.primary_columns = []
@@ -1230,8 +1234,8 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
         # after it's been set up above
         # self._dump_option_struct()
 
-        if select_statement._with_context_options:
-            for fn, key in select_statement._with_context_options:
+        if select_statement._compile_state_funcs:
+            for fn, key in select_statement._compile_state_funcs:
                 fn(self)
 
         self.primary_columns = []
@@ -1339,6 +1343,11 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
 
         self.distinct = query._distinct
 
+        self.syntax_extensions = {
+            key: current_adapter(value, True) if current_adapter else value
+            for key, value in query._get_syntax_extensions_as_dict().items()
+        }
+
         if query._correlate:
             # ORM mapped entities that are mapped to joins can be passed
             # to .correlate, so here they are broken into their component
@@ -1489,7 +1498,7 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
 
         stmt.__dict__.update(
             _with_options=statement._with_options,
-            _with_context_options=statement._with_context_options,
+            _compile_state_funcs=statement._compile_state_funcs,
             _execution_options=statement._execution_options,
             _propagate_attrs=statement._propagate_attrs,
         )
@@ -1723,6 +1732,7 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
         group_by,
         independent_ctes,
         independent_ctes_opts,
+        syntax_extensions,
     ):
         statement = Select._create_raw_select(
             _raw_columns=raw_columns,
@@ -1752,6 +1762,8 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
         statement._fetch_clause_options = fetch_clause_options
         statement._independent_ctes = independent_ctes
         statement._independent_ctes_opts = independent_ctes_opts
+        if syntax_extensions:
+            statement._set_syntax_extensions(**syntax_extensions)
 
         if prefixes:
             statement._prefixes = prefixes
@@ -2421,6 +2433,7 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
             "independent_ctes_opts": (
                 self.select_statement._independent_ctes_opts
             ),
+            "syntax_extensions": self.syntax_extensions,
         }
 
     @property
index 28c282b4872e9199a00a377e103e41f61b36ae2c..00607203c1209e969ad421ca5157cc41cb702a6c 100644 (file)
@@ -137,6 +137,7 @@ if TYPE_CHECKING:
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import CacheableOptions
     from ..sql.base import ExecutableOption
+    from ..sql.base import SyntaxExtension
     from ..sql.dml import UpdateBase
     from ..sql.elements import ColumnElement
     from ..sql.elements import Label
@@ -209,6 +210,8 @@ class Query(
 
     _memoized_select_entities = ()
 
+    _syntax_extensions: Tuple[SyntaxExtension, ...] = ()
+
     _compile_options: Union[Type[CacheableOptions], CacheableOptions] = (
         _ORMCompileState.default_compile_options
     )
@@ -592,7 +595,7 @@ class Query(
             stmt = FromStatement(self._raw_columns, self._statement)
             stmt.__dict__.update(
                 _with_options=self._with_options,
-                _with_context_options=self._with_context_options,
+                _with_context_options=self._compile_state_funcs,
                 _compile_options=compile_options,
                 _execution_options=self._execution_options,
                 _propagate_attrs=self._propagate_attrs,
@@ -600,11 +603,14 @@ class Query(
         else:
             # Query / select() internal attributes are 99% cross-compatible
             stmt = Select._create_raw_select(**self.__dict__)
+
             stmt.__dict__.update(
                 _label_style=self._label_style,
                 _compile_options=compile_options,
                 _propagate_attrs=self._propagate_attrs,
             )
+            for ext in self._syntax_extensions:
+                stmt._apply_syntax_extension_to_self(ext)
             stmt.__dict__.pop("session", None)
 
         # ensure the ORM context is used to compile the statement, even
@@ -1425,6 +1431,7 @@ class Query(
             "_having_criteria",
             "_prefixes",
             "_suffixes",
+            "_syntax_extensions",
         ):
             self.__dict__.pop(attr, None)
         self._set_select_from([fromclause], set_entity_from)
@@ -2703,6 +2710,22 @@ class Query(
             self._distinct = True
         return self
 
+    @_generative
+    def ext(self, extension: SyntaxExtension) -> Self:
+        """Applies a SQL syntax extension to this statement.
+
+        .. seealso::
+
+            :ref:`examples_syntax_extensions`
+
+        .. versionadded:: 2.1
+
+        """
+
+        extension = coercions.expect(roles.SyntaxExtensionRole, extension)
+        self._syntax_extensions += (extension,)
+        return self
+
     def all(self) -> List[_T]:
         """Return the results represented by this :class:`_query.Query`
         as a list.
@@ -3227,6 +3250,10 @@ class Query(
             delete_ = delete_.with_dialect_options(**delete_args)
 
         delete_._where_criteria = self._where_criteria
+
+        for ext in self._syntax_extensions:
+            delete_._apply_syntax_extension_to_self(ext)
+
         result: CursorResult[Any] = self.session.execute(
             delete_,
             self._params,
@@ -3318,6 +3345,10 @@ class Query(
             upd = upd.with_dialect_options(**update_args)
 
         upd._where_criteria = self._where_criteria
+
+        for ext in self._syntax_extensions:
+            upd._apply_syntax_extension_to_self(ext)
+
         result: CursorResult[Any] = self.session.execute(
             upd,
             self._params,
index 8a5d1af961498166f304e5faf5c45aa45d58fdae..8b89eb4523819cd5514750cb539db2eb7ddc17a5 100644 (file)
@@ -1109,8 +1109,8 @@ class _LazyLoader(
                         ]
                     ).lazyload(rev).process_compile_state(compile_context)
 
-        stmt._with_context_options += (
-            (_lazyload_reverse, self.parent_property),
+        stmt = stmt._add_compile_state_func(
+            _lazyload_reverse, self.parent_property
         )
 
         lazy_clause, params = self._generate_lazy_clause(state, passive)
@@ -1774,7 +1774,7 @@ class _SubqueryLoader(_PostLoader):
                     util.to_list(self.parent_property.order_by)
                 )
 
-            q = q._add_context_option(
+            q = q._add_compile_state_func(
                 _setup_outermost_orderby, self.parent_property
             )
 
@@ -3331,7 +3331,7 @@ class _SelectInLoader(_PostLoader, util.MemoizedSlots):
                         util.to_list(self.parent_property.order_by)
                     )
 
-                q = q._add_context_option(
+                q = q._add_compile_state_func(
                     _setup_outermost_orderby, self.parent_property
                 )
 
index 188f709d7e4ab598f0aee07625c72ac1891f4384..4ac8f343d5c1bb4b972dfd0384c72e55cb5bca5e 100644 (file)
@@ -11,6 +11,7 @@ from ._typing import ColumnExpressionArgument as ColumnExpressionArgument
 from ._typing import NotNullable as NotNullable
 from ._typing import Nullable as Nullable
 from .base import Executable as Executable
+from .base import SyntaxExtension as SyntaxExtension
 from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
 from .compiler import FROM_LINTING as FROM_LINTING
 from .compiler import NO_LINTING as NO_LINTING
index 801814f334c91a66bfa45a32b178da4f14f9f047..ee4037a2ffca31b59dc31080ca0e42ea3482ac18 100644 (file)
@@ -59,6 +59,8 @@ from ..util import HasMemoized as HasMemoized
 from ..util import hybridmethod
 from ..util.typing import Self
 from ..util.typing import TypeGuard
+from ..util.typing import TypeVarTuple
+from ..util.typing import Unpack
 
 if TYPE_CHECKING:
     from . import coercions
@@ -68,7 +70,11 @@ if TYPE_CHECKING:
     from ._orm_types import SynchronizeSessionArgument
     from ._typing import _CLE
     from .compiler import SQLCompiler
+    from .dml import Delete
+    from .dml import Insert
+    from .dml import Update
     from .elements import BindParameter
+    from .elements import ClauseElement
     from .elements import ClauseList
     from .elements import ColumnClause  # noqa
     from .elements import ColumnElement
@@ -80,6 +86,7 @@ if TYPE_CHECKING:
     from .selectable import _JoinTargetElement
     from .selectable import _SelectIterable
     from .selectable import FromClause
+    from .selectable import Select
     from ..engine import Connection
     from ..engine import CursorResult
     from ..engine.interfaces import _CoreMultiExecuteParams
@@ -100,6 +107,9 @@ if not TYPE_CHECKING:
     type_api = None  # noqa
 
 
+_Ts = TypeVarTuple("_Ts")
+
+
 class _NoArg(Enum):
     NO_ARG = 0
 
@@ -998,6 +1008,212 @@ class ExecutableOption(HasCopyInternals):
         return c
 
 
+_L = TypeVar("_L", bound=str)
+
+
+class HasSyntaxExtensions(Generic[_L]):
+
+    _position_map: Mapping[_L, str]
+
+    @_generative
+    def ext(self, extension: SyntaxExtension) -> Self:
+        """Applies a SQL syntax extension to this statement.
+
+        SQL syntax extensions are :class:`.ClauseElement` objects that define
+        some vendor-specific syntactical construct that take place in specific
+        parts of a SQL statement.   Examples include vendor extensions like
+        PostgreSQL / SQLite's "ON DUPLICATE KEY UPDATE", PostgreSQL's
+        "DISTINCT ON", and MySQL's "LIMIT" that can be applied to UPDATE
+        and DELETE statements.
+
+        .. seealso::
+
+            :ref:`examples_syntax_extensions`
+
+        .. versionadded:: 2.1
+
+        """
+        extension = coercions.expect(
+            roles.SyntaxExtensionRole, extension, apply_propagate_attrs=self
+        )
+        self._apply_syntax_extension_to_self(extension)
+        return self
+
+    @util.preload_module("sqlalchemy.sql.elements")
+    def apply_syntax_extension_point(
+        self,
+        apply_fn: Callable[[Sequence[ClauseElement]], Sequence[ClauseElement]],
+        position: _L,
+    ) -> None:
+        """Apply a :class:`.SyntaxExtension` to a known extension point.
+
+        Should be used only internally by :class:`.SyntaxExtension`.
+
+        E.g.::
+
+            class Qualify(SyntaxExtension, ClauseElement):
+
+                # ...
+
+                def apply_to_select(self, select_stmt: Select) -> None:
+                    # append self to existing
+                    select_stmt.apply_extension_point(
+                        lambda existing: [*existing, self], "post_criteria"
+                    )
+
+
+            class ReplaceExt(SyntaxExtension, ClauseElement):
+
+                # ...
+
+                def apply_to_select(self, select_stmt: Select) -> None:
+                    # replace any existing elements regardless of type
+                    select_stmt.apply_extension_point(
+                        lambda existing: [self], "post_criteria"
+                    )
+
+
+            class ReplaceOfTypeExt(SyntaxExtension, ClauseElement):
+
+                # ...
+
+                def apply_to_select(self, select_stmt: Select) -> None:
+                    # replace any existing elements of the same type
+                    select_stmt.apply_extension_point(
+                        self.append_replacing_same_type, "post_criteria"
+                    )
+
+        :param apply_fn: callable function that will receive a sequence of
+         :class:`.ClauseElement` that is already populating the extension
+         point (the sequence is empty if there isn't one), and should return
+         a new sequence of :class:`.ClauseElement` that will newly populate
+         that point. The function typically can choose to concatenate the
+         existing values with the new one, or to replace the values that are
+         there with a new one by returning a list of a single element, or
+         to perform more complex operations like removing only the same
+         type element from the input list of merging already existing elements
+         of the same type. Some examples are shown in the examples above
+        :param position: string name of the position to apply to.  This
+         varies per statement type.   IDEs should show the possible values
+         for each statement type as it's typed with a ``typing.Literal`` per
+         statement.
+
+        .. seealso::
+
+            :ref:`examples_syntax_extensions`
+
+
+        """  # noqa: E501
+
+        try:
+            attrname = self._position_map[position]
+        except KeyError as ke:
+            raise ValueError(
+                f"Unknown position {position!r} for {self.__class__} "
+                f"construct; known positions: "
+                f"{', '.join(repr(k) for k in self._position_map)}"
+            ) from ke
+        else:
+            ElementList = util.preloaded.sql_elements.ElementList
+            existing: Optional[ClauseElement] = getattr(self, attrname, None)
+            if existing is None:
+                input_seq: Tuple[ClauseElement, ...] = ()
+            elif isinstance(existing, ElementList):
+                input_seq = existing.clauses
+            else:
+                input_seq = (existing,)
+
+            new_seq = apply_fn(input_seq)
+            assert new_seq, "cannot return empty sequence"
+            new = new_seq[0] if len(new_seq) == 1 else ElementList(new_seq)
+            setattr(self, attrname, new)
+
+    def _apply_syntax_extension_to_self(
+        self, extension: SyntaxExtension
+    ) -> None:
+        raise NotImplementedError()
+
+    def _get_syntax_extensions_as_dict(self) -> Mapping[_L, SyntaxExtension]:
+        res: Dict[_L, SyntaxExtension] = {}
+        for name, attr in self._position_map.items():
+            value = getattr(self, attr)
+            if value is not None:
+                res[name] = value
+        return res
+
+    def _set_syntax_extensions(self, **extensions: SyntaxExtension) -> None:
+        for name, value in extensions.items():
+            setattr(self, self._position_map[name], value)  # type: ignore[index]  # noqa: E501
+
+
+class SyntaxExtension(roles.SyntaxExtensionRole):
+    """Defines a unit that when also extending from :class:`.ClauseElement`
+    can be applied to SQLAlchemy statements :class:`.Select`,
+    :class:`_sql.Insert`, :class:`.Update` and :class:`.Delete` making use of
+    pre-established SQL insertion points within these constructs.
+
+    .. versionadded:: 2.1
+
+    .. seealso::
+
+        :ref:`examples_syntax_extensions`
+
+    """
+
+    def append_replacing_same_type(
+        self, existing: Sequence[ClauseElement]
+    ) -> Sequence[ClauseElement]:
+        """Utility function that can be used as
+        :paramref:`_sql.HasSyntaxExtensions.apply_extension_point.apply_fn`
+        to remove any other element of the same type in existing and appending
+        ``self`` to the list.
+
+        This is equivalent to::
+
+            stmt.apply_extension_point(
+                lambda existing: [
+                    *(e for e in existing if not isinstance(e, ReplaceOfTypeExt)),
+                    self,
+                ],
+                "post_criteria",
+            )
+
+        .. seealso::
+
+            :ref:`examples_syntax_extensions`
+
+            :meth:`_sql.HasSyntaxExtensions.apply_syntax_extension_point`
+
+        """  # noqa: E501
+        cls = type(self)
+        return [*(e for e in existing if not isinstance(e, cls)), self]  # type: ignore[list-item] # noqa: E501
+
+    def apply_to_select(self, select_stmt: Select[Unpack[_Ts]]) -> None:
+        """Apply this :class:`.SyntaxExtension` to a :class:`.Select`"""
+        raise NotImplementedError(
+            f"Extension {type(self).__name__} cannot be applied to select"
+        )
+
+    def apply_to_update(self, update_stmt: Update) -> None:
+        """Apply this :class:`.SyntaxExtension` to an :class:`.Update`"""
+        raise NotImplementedError(
+            f"Extension {type(self).__name__} cannot be applied to update"
+        )
+
+    def apply_to_delete(self, delete_stmt: Delete) -> None:
+        """Apply this :class:`.SyntaxExtension` to a :class:`.Delete`"""
+        raise NotImplementedError(
+            f"Extension {type(self).__name__} cannot be applied to delete"
+        )
+
+    def apply_to_insert(self, insert_stmt: Insert) -> None:
+        """Apply this :class:`.SyntaxExtension` to an
+        :class:`_sql.Insert`"""
+        raise NotImplementedError(
+            f"Extension {type(self).__name__} cannot be applied to insert"
+        )
+
+
 class Executable(roles.StatementRole):
     """Mark a :class:`_expression.ClauseElement` as supporting execution.
 
@@ -1011,7 +1227,7 @@ class Executable(roles.StatementRole):
     _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
     _is_default_generator = False
     _with_options: Tuple[ExecutableOption, ...] = ()
-    _with_context_options: Tuple[
+    _compile_state_funcs: Tuple[
         Tuple[Callable[[CompileState], None], Any], ...
     ] = ()
     _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]]
@@ -1019,8 +1235,8 @@ class Executable(roles.StatementRole):
     _executable_traverse_internals = [
         ("_with_options", InternalTraversal.dp_executable_options),
         (
-            "_with_context_options",
-            ExtendedInternalTraversal.dp_with_context_options,
+            "_compile_state_funcs",
+            ExtendedInternalTraversal.dp_compile_state_funcs,
         ),
         ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs),
     ]
@@ -1076,14 +1292,10 @@ class Executable(roles.StatementRole):
         """Apply options to this statement.
 
         In the general sense, options are any kind of Python object
-        that can be interpreted by the SQL compiler for the statement.
-        These options can be consumed by specific dialects or specific kinds
-        of compilers.
-
-        The most commonly known kind of option are the ORM level options
-        that apply "eager load" and other loading behaviors to an ORM
-        query.   However, options can theoretically be used for many other
-        purposes.
+        that can be interpreted by systems that consume the statement outside
+        of the regular SQL compiler chain.  Specifically, these options are
+        the ORM level options that apply "eager load" and other loading
+        behaviors to an ORM query.
 
         For background on specific kinds of options for specific kinds of
         statements, refer to the documentation for those option objects.
@@ -1127,14 +1339,14 @@ class Executable(roles.StatementRole):
         return self
 
     @_generative
-    def _add_context_option(
+    def _add_compile_state_func(
         self,
         callable_: Callable[[CompileState], None],
         cache_args: Any,
     ) -> Self:
-        """Add a context option to this statement.
+        """Add a compile state function to this statement.
 
-        These are callable functions that will
+        When using the ORM only, these are callable functions that will
         be given the CompileState object upon compilation.
 
         A second argument cache_args is required, which will be combined with
@@ -1142,7 +1354,7 @@ class Executable(roles.StatementRole):
         cache key.
 
         """
-        self._with_context_options += ((callable_, cache_args),)
+        self._compile_state_funcs += ((callable_, cache_args),)
         return self
 
     @overload
index 189c32b27169f6dba86867327cdf8e42ca2987de..5ac11878bac5214528ef7de8491980057aecb3a6 100644 (file)
@@ -478,10 +478,10 @@ class CacheKey(NamedTuple):
         return repr((sql_str, param_tuple))
 
     def __eq__(self, other: Any) -> bool:
-        return bool(self.key == other.key)
+        return other is not None and bool(self.key == other.key)
 
     def __ne__(self, other: Any) -> bool:
-        return not (self.key == other.key)
+        return other is None or not (self.key == other.key)
 
     @classmethod
     def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str:
@@ -629,7 +629,7 @@ class _CacheKeyTraversal(HasTraversalDispatch):
 
     visit_propagate_attrs = PROPAGATE_ATTRS
 
-    def visit_with_context_options(
+    def visit_compile_state_funcs(
         self,
         attrname: str,
         obj: Any,
index 39655e56d940ca774e5062009664a34d9d45cc2a..fc3614c06ba4218421249946d61759b9fbc4dc94 100644 (file)
@@ -52,6 +52,7 @@ if typing.TYPE_CHECKING:
     from ._typing import _DDLColumnArgument
     from ._typing import _DMLTableArgument
     from ._typing import _FromClauseArgument
+    from .base import SyntaxExtension
     from .dml import _DMLTableElement
     from .elements import BindParameter
     from .elements import ClauseElement
@@ -209,6 +210,14 @@ def expect(
 ) -> Union[ColumnElement[Any], TextClause]: ...
 
 
+@overload
+def expect(
+    role: Type[roles.SyntaxExtensionRole],
+    element: Any,
+    **kw: Any,
+) -> SyntaxExtension: ...
+
+
 @overload
 def expect(
     role: Type[roles.LabeledColumnExprRole[Any]],
@@ -926,6 +935,10 @@ class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl):
         return _no_text_coercion(element, argname)
 
 
+class SyntaxExtensionImpl(RoleImpl):
+    __slots__ = ()
+
+
 class StatementOptionImpl(_CoerceLiterals, RoleImpl):
     __slots__ = ()
 
index 9f718133167589a0c6a131d9855a7a4b01aa1ce6..1ee9ff077721ad8e60cfd7b0321ffdb77908bfdb 100644 (file)
@@ -2778,6 +2778,9 @@ class SQLCompiler(Compiled):
     def visit_tuple(self, clauselist, **kw):
         return "(%s)" % self.visit_clauselist(clauselist, **kw)
 
+    def visit_element_list(self, element, **kw):
+        return self._generate_delimited_list(element.clauses, " ", **kw)
+
     def visit_clauselist(self, clauselist, **kw):
         sep = clauselist.operator
         if sep is None:
@@ -4744,6 +4747,11 @@ class SQLCompiler(Compiled):
 
         text = "SELECT "  # we're off to a good start !
 
+        if select_stmt._post_select_clause is not None:
+            psc = self.process(select_stmt._post_select_clause, **kwargs)
+            if psc is not None:
+                text += psc + " "
+
         if select_stmt._hints:
             hint_text, byfrom = self._setup_select_hints(select_stmt)
             if hint_text:
@@ -4760,6 +4768,12 @@ class SQLCompiler(Compiled):
             )
 
         text += self.get_select_precolumns(select_stmt, **kwargs)
+
+        if select_stmt._pre_columns_clause is not None:
+            pcc = self.process(select_stmt._pre_columns_clause, **kwargs)
+            if pcc is not None:
+                text += pcc + " "
+
         # the actual list of columns to print in the SELECT column list.
         inner_columns = [
             c
@@ -4834,6 +4848,11 @@ class SQLCompiler(Compiled):
             kwargs,
         )
 
+        if select_stmt._post_body_clause is not None:
+            pbc = self.process(select_stmt._post_body_clause, **kwargs)
+            if pbc:
+                text += " " + pbc
+
         if select_stmt._statement_hints:
             per_dialect = [
                 ht
@@ -5005,6 +5024,11 @@ class SQLCompiler(Compiled):
             if t:
                 text += " \nHAVING " + t
 
+        if select._post_criteria_clause is not None:
+            pcc = self.process(select._post_criteria_clause, **kwargs)
+            if pcc is not None:
+                text += " \n" + pcc
+
         if select._order_by_clauses:
             text += self.order_by_clause(select, **kwargs)
 
@@ -6134,9 +6158,7 @@ class SQLCompiler(Compiled):
     ):
         """Provide a hook to override the generation of an
         UPDATE..FROM clause.
-
         MySQL and MSSQL override this.
-
         """
         raise NotImplementedError(
             "This backend does not support multiple-table "
@@ -6263,6 +6285,16 @@ class SQLCompiler(Compiled):
         if limit_clause:
             text += " " + limit_clause
 
+        if update_stmt._post_criteria_clause is not None:
+            ulc = self.process(
+                update_stmt._post_criteria_clause,
+                from_linter=from_linter,
+                **kw,
+            )
+
+            if ulc:
+                text += " " + ulc
+
         if (
             self.implicit_returning or update_stmt._returning
         ) and not self.returning_precedes_values:
@@ -6415,6 +6447,15 @@ class SQLCompiler(Compiled):
         if limit_clause:
             text += " " + limit_clause
 
+        if delete_stmt._post_criteria_clause is not None:
+            dlc = self.process(
+                delete_stmt._post_criteria_clause,
+                from_linter=from_linter,
+                **kw,
+            )
+            if dlc:
+                text += " " + dlc
+
         if (
             self.implicit_returning or delete_stmt._returning
         ) and not self.returning_precedes_values:
index e9a59350e344a4d64dce663c16cd19f1a3c0b717..49a43b8eeeee4b9d7872e74237a25efaa547a3c0 100644 (file)
@@ -18,6 +18,7 @@ from typing import cast
 from typing import Dict
 from typing import Iterable
 from typing import List
+from typing import Literal
 from typing import MutableMapping
 from typing import NoReturn
 from typing import Optional
@@ -48,6 +49,8 @@ from .base import DialectKWArgs
 from .base import Executable
 from .base import Generative
 from .base import HasCompileState
+from .base import HasSyntaxExtensions
+from .base import SyntaxExtension
 from .elements import BooleanClauseList
 from .elements import ClauseElement
 from .elements import ColumnClause
@@ -988,7 +991,7 @@ class ValuesBase(UpdateBase):
     """SELECT statement for INSERT .. FROM SELECT"""
 
     _post_values_clause: Optional[ClauseElement] = None
-    """used by extensions to Insert etc. to add additional syntacitcal
+    """used by extensions to Insert etc. to add additional syntactical
     constructs, e.g. ON CONFLICT etc."""
 
     _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None
@@ -1190,12 +1193,16 @@ class ValuesBase(UpdateBase):
         return self
 
 
-class Insert(ValuesBase):
+class Insert(ValuesBase, HasSyntaxExtensions[Literal["post_values"]]):
     """Represent an INSERT construct.
 
     The :class:`_expression.Insert` object is created using the
     :func:`_expression.insert()` function.
 
+    Available extension points:
+
+    * ``post_values``: applies additional logic after the ``VALUES`` clause.
+
     """
 
     __visit_name__ = "insert"
@@ -1235,9 +1242,26 @@ class Insert(ValuesBase):
         + HasCTE._has_ctes_traverse_internals
     )
 
+    _position_map = util.immutabledict(
+        {
+            "post_values": "_post_values_clause",
+        }
+    )
+
+    _post_values_clause: Optional[ClauseElement] = None
+    """extension point for a ClauseElement that will be compiled directly
+    after the VALUES portion of the :class:`.Insert` statement
+
+    """
+
     def __init__(self, table: _DMLTableArgument):
         super().__init__(table)
 
+    def _apply_syntax_extension_to_self(
+        self, extension: SyntaxExtension
+    ) -> None:
+        extension.apply_to_insert(self)
+
     @_generative
     def inline(self) -> Self:
         """Make this :class:`_expression.Insert` construct "inline" .
@@ -1452,10 +1476,25 @@ class ReturningInsert(Insert, TypedReturnsRows[Unpack[_Ts]]):
     """
 
 
+# note: if not for MRO issues, this class should extend
+# from HasSyntaxExtensions[Literal["post_criteria"]]
 class DMLWhereBase:
     table: _DMLTableElement
     _where_criteria: Tuple[ColumnElement[Any], ...] = ()
 
+    _post_criteria_clause: Optional[ClauseElement] = None
+    """used by extensions to Update/Delete etc. to add additional syntacitcal
+    constructs, e.g. LIMIT etc.
+
+    .. versionadded:: 2.1
+
+    """
+
+    # can't put position_map here either without HasSyntaxExtensions
+    # _position_map = util.immutabledict(
+    #     {"post_criteria": "_post_criteria_clause"}
+    # )
+
     @_generative
     def where(self, *whereclause: _ColumnExpressionArgument[bool]) -> Self:
         """Return a new construct with the given expression(s) added to
@@ -1528,12 +1567,18 @@ class DMLWhereBase:
         )
 
 
-class Update(DMLWhereBase, ValuesBase):
+class Update(
+    DMLWhereBase, ValuesBase, HasSyntaxExtensions[Literal["post_criteria"]]
+):
     """Represent an Update construct.
 
     The :class:`_expression.Update` object is created using the
     :func:`_expression.update()` function.
 
+    Available extension points:
+
+    * ``post_criteria``: applies additional logic after the ``WHERE`` clause.
+
     """
 
     __visit_name__ = "update"
@@ -1550,6 +1595,7 @@ class Update(DMLWhereBase, ValuesBase):
             ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
             ("_return_defaults", InternalTraversal.dp_boolean),
+            ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
             (
                 "_return_defaults_columns",
                 InternalTraversal.dp_clauseelement_tuple,
@@ -1561,6 +1607,10 @@ class Update(DMLWhereBase, ValuesBase):
         + HasCTE._has_ctes_traverse_internals
     )
 
+    _position_map = util.immutabledict(
+        {"post_criteria": "_post_criteria_clause"}
+    )
+
     def __init__(self, table: _DMLTableArgument):
         super().__init__(table)
 
@@ -1618,6 +1668,11 @@ class Update(DMLWhereBase, ValuesBase):
         self._inline = True
         return self
 
+    def _apply_syntax_extension_to_self(
+        self, extension: SyntaxExtension
+    ) -> None:
+        extension.apply_to_update(self)
+
     if TYPE_CHECKING:
         # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8
 
@@ -1724,12 +1779,18 @@ class ReturningUpdate(Update, TypedReturnsRows[Unpack[_Ts]]):
     """
 
 
-class Delete(DMLWhereBase, UpdateBase):
+class Delete(
+    DMLWhereBase, UpdateBase, HasSyntaxExtensions[Literal["post_criteria"]]
+):
     """Represent a DELETE construct.
 
     The :class:`_expression.Delete` object is created using the
     :func:`_expression.delete()` function.
 
+    Available extension points:
+
+    * ``post_criteria``: applies additional logic after the ``WHERE`` clause.
+
     """
 
     __visit_name__ = "delete"
@@ -1742,6 +1803,7 @@ class Delete(DMLWhereBase, UpdateBase):
             ("_where_criteria", InternalTraversal.dp_clauseelement_tuple),
             ("_returning", InternalTraversal.dp_clauseelement_tuple),
             ("_hints", InternalTraversal.dp_table_hint_list),
+            ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
         ]
         + HasPrefixes._has_prefixes_traverse_internals
         + DialectKWArgs._dialect_kwargs_traverse_internals
@@ -1749,11 +1811,20 @@ class Delete(DMLWhereBase, UpdateBase):
         + HasCTE._has_ctes_traverse_internals
     )
 
+    _position_map = util.immutabledict(
+        {"post_criteria": "_post_criteria_clause"}
+    )
+
     def __init__(self, table: _DMLTableArgument):
         self.table = coercions.expect(
             roles.DMLTableRole, table, apply_propagate_attrs=self
         )
 
+    def _apply_syntax_extension_to_self(
+        self, extension: SyntaxExtension
+    ) -> None:
+        extension.apply_to_delete(self)
+
     if TYPE_CHECKING:
         # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
 
index bd92f6aa854e0ae8c64174c26a5e3174c6f67243..8d256ea3772f077a0636306b9637af77e4d601a3 100644 (file)
@@ -2717,6 +2717,29 @@ class True_(SingletonConstant, roles.ConstExprRole[bool], ColumnElement[bool]):
 True_._create_singleton()
 
 
+class ElementList(DQLDMLClauseElement):
+    """Describe a list of clauses that will be space separated.
+
+    This is a minimal version of :class:`.ClauseList` which is used by
+    the :class:`.HasSyntaxExtension` class.  It does not do any coercions
+    so should be used internally only.
+
+    .. versionadded:: 2.1
+
+    """
+
+    __visit_name__ = "element_list"
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("clauses", InternalTraversal.dp_clauseelement_tuple),
+    ]
+
+    clauses: typing_Tuple[ClauseElement, ...]
+
+    def __init__(self, clauses: Sequence[ClauseElement]):
+        self.clauses = tuple(clauses)
+
+
 class ClauseList(
     roles.InElementRole,
     roles.OrderByRole,
@@ -3580,6 +3603,7 @@ class _label_reference(ColumnElement[_T]):
 
     def __init__(self, element: ColumnElement[_T]):
         self.element = element
+        self._propagate_attrs = element._propagate_attrs
 
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
@@ -4787,6 +4811,16 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]):
     def _order_by_label_element(self):
         return self
 
+    def as_reference(self) -> _label_reference[_T]:
+        """refer to this labeled expression in a clause such as GROUP BY,
+        ORDER BY etc. as the label name itself, without expanding
+        into the full expression.
+
+        .. versionadded:: 2.1
+
+        """
+        return _label_reference(self)
+
     @HasMemoized.memoized_attribute
     def element(self) -> ColumnElement[_T]:
         return self._element.self_group(against=operators.as_)
index 9c5e43baaccb14f44bc5c9b46327301fea9a5f2e..99f9fc231c44c126eda774c094bdd3d21dfb3af0 100644 (file)
@@ -42,6 +42,11 @@ class SQLRole:
     uses_inspection = False
 
 
+class SyntaxExtensionRole(SQLRole):
+    __slots__ = ()
+    _role_name = "Syntax extension construct"
+
+
 class UsesInspection:
     __slots__ = ()
     _post_inspect: Literal[None] = None
index e53b2bbccc1cc714b0eaa9d8a57718766ddbf711..40f9dbe00425f55563c2a47c077238f289c598c1 100644 (file)
@@ -77,7 +77,9 @@ from .base import Executable
 from .base import Generative
 from .base import HasCompileState
 from .base import HasMemoized
+from .base import HasSyntaxExtensions
 from .base import Immutable
+from .base import SyntaxExtension
 from .coercions import _document_text_coercion
 from .elements import _anonymous_label
 from .elements import BindParameter
@@ -5217,6 +5219,9 @@ class Select(
     HasSuffixes,
     HasHints,
     HasCompileState,
+    HasSyntaxExtensions[
+        Literal["post_select", "pre_columns", "post_criteria", "post_body"]
+    ],
     _SelectFromElements,
     GenerativeSelect,
     TypedReturnsRows[Unpack[_Ts]],
@@ -5226,6 +5231,14 @@ class Select(
     The :class:`_sql.Select` object is normally constructed using the
     :func:`_sql.select` function.  See that function for details.
 
+    Available extension points:
+
+    * ``post_select``: applies additional logic after the ``SELECT`` keyword.
+    * ``pre_columns``: applies additional logic between the ``DISTINCT``
+        keyword (if any) and the list of columns.
+    * ``post_criteria``: applies additional logic after the ``HAVING`` clause.
+    * ``post_body``: applies additional logic after the ``FOR UPDATE`` clause.
+
     .. seealso::
 
         :func:`_sql.select`
@@ -5248,6 +5261,49 @@ class Select(
     _where_criteria: Tuple[ColumnElement[Any], ...] = ()
     _having_criteria: Tuple[ColumnElement[Any], ...] = ()
     _from_obj: Tuple[FromClause, ...] = ()
+
+    _position_map = util.immutabledict(
+        {
+            "post_select": "_post_select_clause",
+            "pre_columns": "_pre_columns_clause",
+            "post_criteria": "_post_criteria_clause",
+            "post_body": "_post_body_clause",
+        }
+    )
+
+    _post_select_clause: Optional[ClauseElement] = None
+    """extension point for a ClauseElement that will be compiled directly
+    after the SELECT keyword.
+
+    .. versionadded:: 2.1
+
+    """
+
+    _pre_columns_clause: Optional[ClauseElement] = None
+    """extension point for a ClauseElement that will be compiled directly
+    before the "columns" clause; after DISTINCT (if present).
+
+    .. versionadded:: 2.1
+
+    """
+
+    _post_criteria_clause: Optional[ClauseElement] = None
+    """extension point for a ClauseElement that will be compiled directly
+    after "criteria", following the HAVING clause but before ORDER BY.
+
+    .. versionadded:: 2.1
+
+    """
+
+    _post_body_clause: Optional[ClauseElement] = None
+    """extension point for a ClauseElement that will be compiled directly
+    after the "body", following the ORDER BY, LIMIT, and FOR UPDATE sections
+    of the SELECT.
+
+    .. versionadded:: 2.1
+
+    """
+
     _auto_correlate = True
     _is_select_statement = True
     _compile_options: CacheableOptions = (
@@ -5277,6 +5333,10 @@ class Select(
             ("_distinct", InternalTraversal.dp_boolean),
             ("_distinct_on", InternalTraversal.dp_clauseelement_tuple),
             ("_label_style", InternalTraversal.dp_plain_obj),
+            ("_post_select_clause", InternalTraversal.dp_clauseelement),
+            ("_pre_columns_clause", InternalTraversal.dp_clauseelement),
+            ("_post_criteria_clause", InternalTraversal.dp_clauseelement),
+            ("_post_body_clause", InternalTraversal.dp_clauseelement),
         ]
         + HasCTE._has_ctes_traverse_internals
         + HasPrefixes._has_prefixes_traverse_internals
@@ -5321,6 +5381,11 @@ class Select(
 
         GenerativeSelect.__init__(self)
 
+    def _apply_syntax_extension_to_self(
+        self, extension: SyntaxExtension
+    ) -> None:
+        extension.apply_to_select(self)
+
     def _scalar_type(self) -> TypeEngine[Any]:
         if not self._raw_columns:
             return NULLTYPE
index 13ad28996e086cddf458dba855b74eda895b5028..38f8e3e162355017efaa43705c3baf0b1b2efdc4 100644 (file)
@@ -668,6 +668,19 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
             for l, r in zip_longest(ltup, rtup, fillvalue=None):
                 self.stack.append((l, r))
 
+    def visit_multi_list(
+        self, attrname, left_parent, left, right_parent, right, **kw
+    ):
+        for l, r in zip_longest(left, right, fillvalue=None):
+            if isinstance(l, str):
+                if not isinstance(r, str) or l != r:
+                    return COMPARE_FAILED
+            elif isinstance(r, str):
+                if not isinstance(l, str) or l != r:
+                    return COMPARE_FAILED
+            else:
+                self.stack.append((l, r))
+
     def visit_clauseelement_list(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
@@ -796,7 +809,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
         else:
             return left == right
 
-    def visit_with_context_options(
+    def visit_compile_state_funcs(
         self, attrname, left_parent, left, right_parent, right, **kw
     ):
         return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
index 7ae89216877a083bec4bf29dbfa00f82783622a6..34ac84953bcb875b22f8b264403bf9b21d9b35f8 100644 (file)
@@ -218,7 +218,7 @@ class InternalTraversal(Enum):
 
     dp_executable_options = "EO"
 
-    dp_with_context_options = "WC"
+    dp_compile_state_funcs = "WC"
 
     dp_fromclause_ordered_set = "CO"
     """Visit an ordered set of :class:`_expression.FromClause` objects. """
index f2948dee8d3dc86d520fe695a013ceeed3e0b177..ae88818300a1f7d77dca49090c2d8d866c86cf15 100644 (file)
@@ -19,6 +19,7 @@ from .orm import (
     stop_test_class_inside_fixtures as stop_test_class_inside_fixtures,
 )
 from .sql import CacheKeyFixture as CacheKeyFixture
+from .sql import CacheKeySuite as CacheKeySuite
 from .sql import (
     ComputedReflectionFixtureTest as ComputedReflectionFixtureTest,
 )
index 44cf21c24fe08ee4ecaee85bab554523ad347331..d1f06683f1bf906274c0cf6737c8070325a5c098 100644 (file)
@@ -341,12 +341,15 @@ class ComputedReflectionFixtureTest(TablesTest):
 
 
 class CacheKeyFixture:
-    def _compare_equal(self, a, b, compare_values):
+    def _compare_equal(self, a, b, *, compare_values=False):
         a_key = a._generate_cache_key()
         b_key = b._generate_cache_key()
 
         if a_key is None:
-            assert a._annotations.get("nocache")
+            assert a._annotations.get("nocache"), (
+                "Construct doesn't cache, so test suite should "
+                "add the 'nocache' annotation"
+            )
 
             assert b_key is None
         else:
@@ -357,7 +360,23 @@ class CacheKeyFixture:
                 assert a_param.compare(b_param, compare_values=compare_values)
         return a_key, b_key
 
-    def _run_cache_key_fixture(self, fixture, compare_values):
+    def _run_compare_fixture(self, fixture, *, compare_values=False):
+        case_a = fixture()
+        case_b = fixture()
+
+        for a, b in itertools.combinations_with_replacement(
+            range(len(case_a)), 2
+        ):
+            if a == b:
+                assert case_a[a].compare(
+                    case_b[b], compare_values=compare_values
+                )
+            else:
+                assert not case_a[a].compare(
+                    case_b[b], compare_values=compare_values
+                )
+
+    def _run_cache_key_fixture(self, fixture, *, compare_values=False):
         case_a = fixture()
         case_b = fixture()
 
@@ -366,7 +385,7 @@ class CacheKeyFixture:
         ):
             if a == b:
                 a_key, b_key = self._compare_equal(
-                    case_a[a], case_b[b], compare_values
+                    case_a[a], case_b[b], compare_values=compare_values
                 )
                 if a_key is None:
                     continue
@@ -439,7 +458,20 @@ class CacheKeyFixture:
         for a, b in itertools.combinations_with_replacement(
             range(len(case_a)), 2
         ):
-            self._compare_equal(case_a[a], case_b[b], compare_values)
+            self._compare_equal(
+                case_a[a], case_b[b], compare_values=compare_values
+            )
+
+
+class CacheKeySuite(CacheKeyFixture):
+    @classmethod
+    def run_suite_tests(cls, fn):
+        def decorate(self):
+            self._run_cache_key_fixture(fn(self), compare_values=False)
+            self._run_compare_fixture(fn(self), compare_values=False)
+
+        decorate.__name__ = fn.__name__
+        return decorate
 
 
 def insertmanyvalues_fixture(
index 4baddfb105ae9713aab1f2e83461f1c02d7a92ad..4ccdd29b2d19fb3f279f6e53787114fa5e05a9ef 100644 (file)
@@ -29,3 +29,12 @@ class VersionedRowsTestNewBase(
     fixtures.TestBase,
 ):
     pass
+
+
+test_qualify = __import__(
+    "examples.syntax_extensions.test_qualify"
+).syntax_extensions.test_qualify
+
+
+class QualifyCompileTest(test_qualify.QualifyCompileTest, fixtures.TestBase):
+    pass
diff --git a/test/orm/test_syntax_extensions.py b/test/orm/test_syntax_extensions.py
new file mode 100644 (file)
index 0000000..08a366c
--- /dev/null
@@ -0,0 +1,264 @@
+from __future__ import annotations
+
+from typing import Any
+
+from sqlalchemy import insert
+from sqlalchemy import inspect
+from sqlalchemy import select
+from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
+from sqlalchemy.sql._typing import _ColumnExpressionArgument
+from sqlalchemy.sql.base import SyntaxExtension
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.dml import Update
+from sqlalchemy.sql.visitors import _TraverseInternalsType
+from sqlalchemy.sql.visitors import InternalTraversal
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from .test_query import QueryTest
+
+
+class PostSelectClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_select",
+        )
+
+
+class PreColumnsClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "pre_columns",
+        )
+
+
+class PostCriteriaClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_criteria",
+        )
+
+    def apply_to_update(self, update_stmt: Update) -> None:
+        update_stmt.apply_syntax_extension_point(
+            lambda existing: [self], "post_criteria"
+        )
+
+    def apply_to_delete(self, delete_stmt: Delete) -> None:
+        delete_stmt.apply_syntax_extension_point(
+            lambda existing: [self], "post_criteria"
+        )
+
+
+class PostCriteriaClause2(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_criteria",
+        )
+
+
+class PostCriteriaClauseCols(PostCriteriaClause):
+    _traverse_internals: _TraverseInternalsType = [
+        ("exprs", InternalTraversal.dp_clauseelement_tuple),
+    ]
+
+    def __init__(self, *exprs: _ColumnExpressionArgument[Any]):
+        self.exprs = tuple(
+            coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+            for e in exprs
+        )
+
+
+class PostCriteriaClauseColsNoProp(PostCriteriaClause):
+    _traverse_internals: _TraverseInternalsType = [
+        ("exprs", InternalTraversal.dp_clauseelement_tuple),
+    ]
+
+    def __init__(self, *exprs: _ColumnExpressionArgument[Any]):
+        self.exprs = tuple(coercions.expect(roles.ByOfRole, e) for e in exprs)
+
+
+class PostBodyClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [self],
+            "post_body",
+        )
+
+
+class PostValuesClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_insert(self, insert_stmt):
+        insert_stmt.apply_syntax_extension_point(
+            lambda existing: [self],
+            "post_values",
+        )
+
+
+@compiles(PostSelectClause)
+def _compile_psk(element, compiler, **kw):
+    return "POST SELECT KEYWORD"
+
+
+@compiles(PreColumnsClause)
+def _compile_pcc(element, compiler, **kw):
+    return "PRE COLUMNS"
+
+
+@compiles(PostCriteriaClause)
+def _compile_psc(element, compiler, **kw):
+    return "POST CRITERIA"
+
+
+@compiles(PostCriteriaClause2)
+def _compile_psc2(element, compiler, **kw):
+    return "2 POST CRITERIA 2"
+
+
+@compiles(PostCriteriaClauseCols)
+def _compile_psc_cols(element, compiler, **kw):
+    return f"""PC COLS ({
+        ', '.join(compiler.process(expr, **kw) for expr in element.exprs)
+    })"""
+
+
+@compiles(PostBodyClause)
+def _compile_psb(element, compiler, **kw):
+    return "POST SELECT BODY"
+
+
+@compiles(PostValuesClause)
+def _compile_pvc(element, compiler, **kw):
+    return "POST VALUES"
+
+
+class TestExtensionPoints(QueryTest, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def test_select_post_select_clause(self):
+        User = self.classes.User
+
+        stmt = select(User).ext(PostSelectClause()).where(User.name == "x")
+        self.assert_compile(
+            stmt,
+            "SELECT POST SELECT KEYWORD users.id, users.name "
+            "FROM users WHERE users.name = :name_1",
+        )
+
+    def test_select_pre_columns_clause(self):
+        User = self.classes.User
+
+        stmt = select(User).ext(PreColumnsClause()).where(User.name == "x")
+        self.assert_compile(
+            stmt,
+            "SELECT PRE COLUMNS users.id, users.name FROM users "
+            "WHERE users.name = :name_1",
+        )
+
+    def test_select_post_criteria_clause(self):
+        User = self.classes.User
+
+        stmt = select(User).ext(PostCriteriaClause()).where(User.name == "x")
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users "
+            "WHERE users.name = :name_1 POST CRITERIA",
+        )
+
+    def test_select_post_criteria_clause_multiple(self):
+        User = self.classes.User
+
+        stmt = (
+            select(User)
+            .ext(PostCriteriaClause())
+            .ext(PostCriteriaClause2())
+            .where(User.name == "x")
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users "
+            "WHERE users.name = :name_1 POST CRITERIA 2 POST CRITERIA 2",
+        )
+
+    def test_select_post_select_body(self):
+        User = self.classes.User
+
+        stmt = select(User).ext(PostBodyClause()).where(User.name == "x")
+
+        self.assert_compile(
+            stmt,
+            "SELECT users.id, users.name FROM users "
+            "WHERE users.name = :name_1 POST SELECT BODY",
+        )
+
+    def test_insert_post_values(self):
+        User = self.classes.User
+
+        self.assert_compile(
+            insert(User).ext(PostValuesClause()),
+            "INSERT INTO users (id, name) VALUES (:id, :name) POST VALUES",
+        )
+
+    def test_update_post_criteria(self):
+        User = self.classes.User
+
+        self.assert_compile(
+            update(User).ext(PostCriteriaClause()).where(User.name == "hi"),
+            "UPDATE users SET id=:id, name=:name "
+            "WHERE users.name = :name_1 POST CRITERIA",
+        )
+
+    @testing.combinations(
+        (lambda User: select(1).ext(PostCriteriaClauseCols(User.id)), True),
+        (
+            lambda User: select(1).ext(PostCriteriaClauseColsNoProp(User.id)),
+            False,
+        ),
+        (
+            lambda User, users: users.update().ext(
+                PostCriteriaClauseCols(User.id)
+            ),
+            True,
+        ),
+        (
+            lambda User, users: users.delete().ext(
+                PostCriteriaClauseCols(User.id)
+            ),
+            True,
+        ),
+        (lambda User, users: users.delete(), False),
+    )
+    def test_propagate_attrs(self, stmt, expected):
+        User = self.classes.User
+        user_table = self.tables.users
+
+        stmt = testing.resolve_lambda(stmt, User=User, users=user_table)
+
+        if expected:
+            eq_(
+                stmt._propagate_attrs,
+                {
+                    "compile_state_plugin": "orm",
+                    "plugin_subject": inspect(User),
+                },
+            )
+        else:
+            eq_(stmt._propagate_attrs, {})
index 5c7c5053e963e086d90d7c0bbd40a7bb0c0f4503..d499609b49595ab764b9a1aaefae15aee0649b88 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import case
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import column
+from sqlalchemy import DateTime
 from sqlalchemy import dialects
 from sqlalchemy import exists
 from sqlalchemy import extract
@@ -46,15 +47,19 @@ from sqlalchemy.sql import visitors
 from sqlalchemy.sql.annotation import Annotated
 from sqlalchemy.sql.base import HasCacheKey
 from sqlalchemy.sql.base import SingletonConstant
+from sqlalchemy.sql.base import SyntaxExtension
 from sqlalchemy.sql.elements import _label_reference
 from sqlalchemy.sql.elements import _textual_label_reference
 from sqlalchemy.sql.elements import BindParameter
 from sqlalchemy.sql.elements import ClauseElement
 from sqlalchemy.sql.elements import ClauseList
 from sqlalchemy.sql.elements import CollationClause
+from sqlalchemy.sql.elements import DQLDMLClauseElement
+from sqlalchemy.sql.elements import ElementList
 from sqlalchemy.sql.elements import Immutable
 from sqlalchemy.sql.elements import Null
 from sqlalchemy.sql.elements import Slice
+from sqlalchemy.sql.elements import TypeClause
 from sqlalchemy.sql.elements import UnaryExpression
 from sqlalchemy.sql.functions import FunctionElement
 from sqlalchemy.sql.functions import GenericFunction
@@ -190,6 +195,15 @@ class CoreFixtures:
             _label_reference(table_a.c.a.desc()),
             _label_reference(table_a.c.a.asc()),
         ),
+        lambda: (
+            TypeClause(String(50)),
+            TypeClause(DateTime()),
+        ),
+        lambda: (
+            table_a.c.a,
+            ElementList([table_a.c.a]),
+            ElementList([table_a.c.a, table_a.c.b]),
+        ),
         lambda: (_textual_label_reference("a"), _textual_label_reference("b")),
         lambda: (
             text("select a, b from table").columns(a=Integer, b=String),
@@ -987,15 +1001,15 @@ class CoreFixtures:
 
     def _statements_w_context_options_fixtures():
         return [
-            select(table_a)._add_context_option(opt1, True),
-            select(table_a)._add_context_option(opt1, 5),
+            select(table_a)._add_compile_state_func(opt1, True),
+            select(table_a)._add_compile_state_func(opt1, 5),
             select(table_a)
-            ._add_context_option(opt1, True)
-            ._add_context_option(opt2, True),
+            ._add_compile_state_func(opt1, True)
+            ._add_compile_state_func(opt2, True),
             select(table_a)
-            ._add_context_option(opt1, True)
-            ._add_context_option(opt2, 5),
-            select(table_a)._add_context_option(opt3, True),
+            ._add_compile_state_func(opt1, True)
+            ._add_compile_state_func(opt2, 5),
+            select(table_a)._add_compile_state_func(opt3, True),
         ]
 
     fixtures.append(_statements_w_context_options_fixtures)
@@ -1289,7 +1303,7 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
             # a typed column expression, so this is fine
             return (column("x", Integer).in_(elements),)
 
-        self._run_cache_key_fixture(fixture, False)
+        self._run_cache_key_fixture(fixture, compare_values=False)
 
     def test_cache_key(self):
         for fixtures_, compare_values in [
@@ -1298,7 +1312,9 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
             (self.type_cache_key_fixtures, False),
         ]:
             for fixture in fixtures_:
-                self._run_cache_key_fixture(fixture, compare_values)
+                self._run_cache_key_fixture(
+                    fixture, compare_values=compare_values
+                )
 
     def test_cache_key_equal(self):
         for fixture in self.equal_fixtures:
@@ -1313,7 +1329,7 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
 
         self._run_cache_key_fixture(
             fixture,
-            True,
+            compare_values=True,
         )
 
     def test_bindparam_subclass_nocache(self):
@@ -1336,7 +1352,7 @@ class CacheKeyTest(fixtures.CacheKeyFixture, CoreFixtures, fixtures.TestBase):
                 _literal_bindparam(None),
             )
 
-        self._run_cache_key_fixture(fixture, True)
+        self._run_cache_key_fixture(fixture, compare_values=True)
 
     def test_cache_key_unknown_traverse(self):
         class Foobar1(ClauseElement):
@@ -1548,7 +1564,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
         ),
         "FromStatement": (
             {"_raw_columns", "_with_options", "element"}
-            | {"_propagate_attrs", "_with_context_options"},
+            | {"_propagate_attrs", "_compile_state_funcs"},
             {"element", "entities"},
         ),
         "FunctionAsBinary": (
@@ -1604,7 +1620,7 @@ class HasCacheKeySubclass(fixtures.TestBase):
                 "_hints",
                 "_independent_ctes",
                 "_distinct_on",
-                "_with_context_options",
+                "_compile_state_funcs",
                 "_setup_joins",
                 "_suffixes",
                 "_memoized_select_entities",
@@ -1619,6 +1635,10 @@ class HasCacheKeySubclass(fixtures.TestBase):
                 "_annotations",
                 "_fetch_clause_options",
                 "_from_obj",
+                "_post_select_clause",
+                "_post_body_clause",
+                "_post_criteria_clause",
+                "_pre_columns_clause",
             },
             {"entities"},
         ),
@@ -1658,7 +1678,12 @@ class HasCacheKeySubclass(fixtures.TestBase):
 
     @testing.combinations(
         *all_hascachekey_subclasses(
-            ignore_subclasses=[Annotated, NoInit, SingletonConstant]
+            ignore_subclasses=[
+                Annotated,
+                NoInit,
+                SingletonConstant,
+                SyntaxExtension,
+            ]
         )
     )
     def test_init_args_in_traversal(self, cls: type):
@@ -1705,7 +1730,15 @@ class CompareAndCopyTest(CoreFixtures, fixtures.TestBase):
             if "orm" not in cls.__module__
             and "compiler" not in cls.__module__
             and "dialects" not in cls.__module__
-            and issubclass(cls, (ColumnElement, Selectable, LambdaElement))
+            and issubclass(
+                cls,
+                (
+                    ColumnElement,
+                    Selectable,
+                    LambdaElement,
+                    DQLDMLClauseElement,
+                ),
+            )
         )
 
         for fixture in self.fixtures + self.dont_compare_values_fixtures:
diff --git a/test/sql/test_syntax_extensions.py b/test/sql/test_syntax_extensions.py
new file mode 100644 (file)
index 0000000..0279f44
--- /dev/null
@@ -0,0 +1,324 @@
+from sqlalchemy import Column
+from sqlalchemy import column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import select
+from sqlalchemy import Table
+from sqlalchemy import table
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import ClauseElement
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
+from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.base import SyntaxExtension
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.dml import Update
+from sqlalchemy.sql.visitors import _TraverseInternalsType
+from sqlalchemy.sql.visitors import InternalTraversal
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+
+
+class PostSelectClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_select",
+        )
+
+
+class PreColumnsClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "pre_columns",
+        )
+
+
+class PostCriteriaClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_criteria",
+        )
+
+    def apply_to_update(self, update_stmt: Update) -> None:
+        update_stmt.apply_syntax_extension_point(
+            lambda existing: [self], "post_criteria"
+        )
+
+    def apply_to_delete(self, delete_stmt: Delete) -> None:
+        delete_stmt.apply_syntax_extension_point(
+            lambda existing: [self], "post_criteria"
+        )
+
+
+class PostCriteriaClause2(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            self.append_replacing_same_type,
+            "post_criteria",
+        )
+
+
+class PostCriteriaClause3(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [self],
+            "post_criteria",
+        )
+
+
+class PostBodyClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [self],
+            "post_body",
+        )
+
+
+class PostValuesClause(SyntaxExtension, ClauseElement):
+    _traverse_internals = []
+
+    def apply_to_insert(self, insert_stmt):
+        insert_stmt.apply_syntax_extension_point(
+            lambda existing: [self],
+            "post_values",
+        )
+
+
+class ColumnExpressionExt(SyntaxExtension, ClauseElement):
+    _traverse_internals: _TraverseInternalsType = [
+        ("_exprs", InternalTraversal.dp_clauseelement_tuple),
+    ]
+
+    def __init__(self, *exprs):
+        self._exprs = tuple(
+            coercions.expect(roles.ByOfRole, e, apply_propagate_attrs=self)
+            for e in exprs
+        )
+
+    def apply_to_select(self, select_stmt):
+        select_stmt.apply_syntax_extension_point(
+            lambda existing: [*existing, self],
+            "post_select",
+        )
+
+
+@compiles(PostSelectClause)
+def _compile_psk(element, compiler, **kw):
+    return "POST SELECT KEYWORD"
+
+
+@compiles(PreColumnsClause)
+def _compile_pcc(element, compiler, **kw):
+    return "PRE COLUMNS"
+
+
+@compiles(PostCriteriaClause)
+def _compile_psc(element, compiler, **kw):
+    return "POST CRITERIA"
+
+
+@compiles(PostCriteriaClause2)
+def _compile_psc2(element, compiler, **kw):
+    return "2 POST CRITERIA 2"
+
+
+@compiles(PostCriteriaClause3)
+def _compile_psc3(element, compiler, **kw):
+    return "3 POST CRITERIA 3"
+
+
+@compiles(PostBodyClause)
+def _compile_psb(element, compiler, **kw):
+    return "POST SELECT BODY"
+
+
+@compiles(PostValuesClause)
+def _compile_pvc(element, compiler, **kw):
+    return "POST VALUES"
+
+
+@compiles(ColumnExpressionExt)
+def _compile_cee(element, compiler, **kw):
+    inner = ", ".join(compiler.process(elem, **kw) for elem in element._exprs)
+    return f"COLUMN EXPRESSIONS ({inner})"
+
+
+class TestExtensionPoints(fixtures.TestBase, AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    def test_illegal_section(self):
+        class SomeExtension(SyntaxExtension, ClauseElement):
+            _traverse_internals = []
+
+            def apply_to_select(self, select_stmt):
+                select_stmt.apply_syntax_extension_point(
+                    lambda existing: [self],
+                    "not_present",
+                )
+
+        with expect_raises_message(
+            ValueError,
+            r"Unknown position 'not_present' for <class .*Select'> "
+            "construct; known positions: "
+            "'post_select', 'pre_columns', 'post_criteria', 'post_body'",
+        ):
+            select(column("q")).ext(SomeExtension())
+
+    def test_select_post_select_clause(self):
+        self.assert_compile(
+            select(column("a"), column("b"))
+            .ext(PostSelectClause())
+            .where(column("q") == 5),
+            "SELECT POST SELECT KEYWORD a, b WHERE q = :q_1",
+        )
+
+    def test_select_pre_columns_clause(self):
+        self.assert_compile(
+            select(column("a"), column("b"))
+            .ext(PreColumnsClause())
+            .where(column("q") == 5)
+            .distinct(),
+            "SELECT DISTINCT PRE COLUMNS a, b WHERE q = :q_1",
+        )
+
+    def test_select_post_criteria_clause(self):
+        self.assert_compile(
+            select(column("a"), column("b"))
+            .ext(PostCriteriaClause())
+            .where(column("q") == 5)
+            .having(column("z") == 10)
+            .order_by(column("r")),
+            "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+            "POST CRITERIA ORDER BY r",
+        )
+
+    def test_select_post_criteria_clause_multiple(self):
+        self.assert_compile(
+            select(column("a"), column("b"))
+            .ext(PostCriteriaClause())
+            .ext(PostCriteriaClause2())
+            .where(column("q") == 5)
+            .having(column("z") == 10)
+            .order_by(column("r")),
+            "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+            "POST CRITERIA 2 POST CRITERIA 2 ORDER BY r",
+        )
+
+    def test_select_post_criteria_clause_multiple2(self):
+        stmt = (
+            select(column("a"), column("b"))
+            .ext(PostCriteriaClause())
+            .ext(PostCriteriaClause())
+            .ext(PostCriteriaClause2())
+            .ext(PostCriteriaClause2())
+            .where(column("q") == 5)
+            .having(column("z") == 10)
+            .order_by(column("r"))
+        )
+        # PostCriteriaClause2 is here only once
+        self.assert_compile(
+            stmt,
+            "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+            "POST CRITERIA POST CRITERIA 2 POST CRITERIA 2 ORDER BY r",
+        )
+        # now there is only PostCriteriaClause3
+        self.assert_compile(
+            stmt.ext(PostCriteriaClause3()),
+            "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+            "3 POST CRITERIA 3 ORDER BY r",
+        )
+
+    def test_select_post_select_body(self):
+        self.assert_compile(
+            select(column("a"), column("b"))
+            .ext(PostBodyClause())
+            .where(column("q") == 5)
+            .having(column("z") == 10)
+            .order_by(column("r"))
+            .limit(15),
+            "SELECT a, b WHERE q = :q_1 HAVING z = :z_1 "
+            "ORDER BY r LIMIT :param_1 POST SELECT BODY",
+        )
+
+    def test_insert_post_values(self):
+        t = table("t", column("a"), column("b"))
+        self.assert_compile(
+            t.insert().ext(PostValuesClause()),
+            "INSERT INTO t (a, b) VALUES (:a, :b) POST VALUES",
+        )
+
+    def test_update_post_criteria(self):
+        t = table("t", column("a"), column("b"))
+        self.assert_compile(
+            t.update().ext(PostCriteriaClause()).where(t.c.a == "hi"),
+            "UPDATE t SET a=:a, b=:b WHERE t.a = :a_1 POST CRITERIA",
+        )
+
+    def test_delete_post_criteria(self):
+        t = table("t", column("a"), column("b"))
+        self.assert_compile(
+            t.delete().ext(PostCriteriaClause()).where(t.c.a == "hi"),
+            "DELETE FROM t WHERE t.a = :a_1 POST CRITERIA",
+        )
+
+
+class TestExpressionExtensions(
+    fixtures.CacheKeyFixture, fixtures.TestBase, AssertsCompiledSQL
+):
+    __dialect__ = "default"
+
+    def test_render(self):
+        t = Table(
+            "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+        )
+
+        stmt = select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2))
+        self.assert_compile(
+            stmt,
+            "SELECT COLUMN EXPRESSIONS (t1.c1, t1.c2) t1.c1, t1.c2 FROM t1",
+        )
+
+    def test_adaptation(self):
+        t = Table(
+            "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+        )
+
+        s1 = select(t).subquery()
+        s2 = select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2))
+        s3 = sql_util.ClauseAdapter(s1).traverse(s2)
+
+        self.assert_compile(
+            s3,
+            "SELECT COLUMN EXPRESSIONS (anon_1.c1, anon_1.c2) "
+            "anon_1.c1, anon_1.c2 FROM "
+            "(SELECT t1.c1 AS c1, t1.c2 AS c2 FROM t1) AS anon_1",
+        )
+
+    def test_compare(self):
+        t = Table(
+            "t1", MetaData(), Column("c1", Integer), Column("c2", Integer)
+        )
+
+        self._run_compare_fixture(
+            lambda: (
+                select(t).ext(ColumnExpressionExt(t.c.c1, t.c.c2)),
+                select(t).ext(ColumnExpressionExt(t.c.c1)),
+                select(t),
+            )
+        )