From f522e43cc7c31d3aaffb4e126d2d06a719e0d157 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 23 Jun 2024 15:40:31 -0400 Subject: [PATCH] use a new ClauseElement for Over.range_ / Over.rows Enhanced the caching structure of the :paramref:`.over.rows` and :paramref:`.over.range_` so that different numerical values for the rows / range fields are cached on the same cache key, to the extent that the underlying SQL does not actually change (i.e. "unbounded", "current row", negative/positive status will still change the cache key). This prevents the use of many different numerical range/rows values for a query that is otherwise identical from filling up the SQL cache. Note that the semi-private compiler method ``_format_frame_clause()`` is removed by this fix, replaced with a new method ``visit_frame_clause()``. Third party dialects which may have referred to this method will need to change the name and revise the approach to rendering the correct SQL for that dialect. This patch introduces a new ClauseElement called _FrameClause which stores the integer range values within cache-compatible BindParameter objects, separately from the "type", which can be unbounded, current, preceding, or following and is represented by a _FrameClauseType enum. The negative sign is also stripped from the integer and represented within the _FrameClauseType. Tests from #11514 are adapted to include a test for SQL Server's "literal_execute" flag taking effect so that literal numeric values aren't stored in the cache. 
Fixes: #11515 Change-Id: I8aad368ffef9f06cb5c3f8c4e971fadef029ffd5 --- doc/build/changelog/unreleased_21/11515.rst | 18 +++ lib/sqlalchemy/dialects/mssql/base.py | 4 +- lib/sqlalchemy/sql/compiler.py | 84 +++++------ lib/sqlalchemy/sql/elements.py | 150 +++++++++++--------- lib/sqlalchemy/testing/suite/test_select.py | 40 ++++-- test/sql/test_compare.py | 26 ++++ 6 files changed, 193 insertions(+), 129 deletions(-) create mode 100644 doc/build/changelog/unreleased_21/11515.rst diff --git a/doc/build/changelog/unreleased_21/11515.rst b/doc/build/changelog/unreleased_21/11515.rst new file mode 100644 index 0000000000..507ab3f814 --- /dev/null +++ b/doc/build/changelog/unreleased_21/11515.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, sql + :tickets: 11515 + + Enhanced the caching structure of the :paramref:`.over.rows` and + :paramref:`.over.range` so that different numerical values for the rows / + range fields are cached on the same cache key, to the extent that the + underlying SQL does not actually change (i.e. "unbounded", "current row", + negative/positive status will still change the cache key). This prevents + the use of many different numerical range/rows value for a query that is + otherwise identical from filling up the SQL cache. + + Note that the semi-private compiler method ``_format_frame_clause()`` + is removed by this fix, replaced with a new method + ``visit_frame_clause()``. Third party dialects which may have referred + to this method will need to change the name and revise the approach to + rendering the correct SQL for that dialect. 
+ diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index ddee9a5a73..57b273e1a8 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1988,9 +1988,9 @@ class MSSQLCompiler(compiler.SQLCompiler): self.tablealiases = {} super().__init__(*args, **kwargs) - def _format_frame_clause(self, range_, **kw): + def visit_frame_clause(self, frameclause, **kw): kw["literal_execute"] = True - return super()._format_frame_clause(range_, **kw) + return super().visit_frame_clause(frameclause, **kw) def _with_legacy_schema_aliasing(fn): def decorate(self, *arg, **kw): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 88e14645bb..18baf0f8e7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2836,58 +2836,44 @@ class SQLCompiler(Compiled): match.group(2) if match else "", ) - def _format_frame_clause(self, range_, **kw): - return "%s AND %s" % ( - ( - "UNBOUNDED PRECEDING" - if range_[0] is elements.RANGE_UNBOUNDED - else ( - "CURRENT ROW" - if range_[0] is elements.RANGE_CURRENT - else ( - "%s PRECEDING" - % ( - self.process( - elements.literal(abs(range_[0])), **kw - ), - ) - if range_[0] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[0]), **kw),) - ) - ) - ), - ( - "UNBOUNDED FOLLOWING" - if range_[1] is elements.RANGE_UNBOUNDED - else ( - "CURRENT ROW" - if range_[1] is elements.RANGE_CURRENT - else ( - "%s PRECEDING" - % ( - self.process( - elements.literal(abs(range_[1])), **kw - ), - ) - if range_[1] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[1]), **kw),) - ) - ) - ), - ) + def visit_frame_clause(self, frameclause, **kw): + + if frameclause.lower_type is elements._FrameClauseType.RANGE_UNBOUNDED: + left = "UNBOUNDED PRECEDING" + elif frameclause.lower_type is elements._FrameClauseType.RANGE_CURRENT: + left = "CURRENT ROW" + else: + val = self.process(frameclause.lower_integer_bind, 
**kw) + if ( + frameclause.lower_type + is elements._FrameClauseType.RANGE_PRECEDING + ): + left = f"{val} PRECEDING" + else: + left = f"{val} FOLLOWING" + + if frameclause.upper_type is elements._FrameClauseType.RANGE_UNBOUNDED: + right = "UNBOUNDED FOLLOWING" + elif frameclause.upper_type is elements._FrameClauseType.RANGE_CURRENT: + right = "CURRENT ROW" + else: + val = self.process(frameclause.upper_integer_bind, **kw) + if ( + frameclause.upper_type + is elements._FrameClauseType.RANGE_PRECEDING + ): + right = f"{val} PRECEDING" + else: + right = f"{val} FOLLOWING" + + return f"{left} AND {right}" def visit_over(self, over, **kwargs): text = over.element._compiler_dispatch(self, **kwargs) - if over.range_: - range_ = "RANGE BETWEEN %s" % self._format_frame_clause( - over.range_, **kwargs - ) - elif over.rows: - range_ = "ROWS BETWEEN %s" % self._format_frame_clause( - over.rows, **kwargs - ) + if over.range_ is not None: + range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}" + elif over.rows is not None: + range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}" else: range_ = None diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 80e98c1e19..a4841e07f3 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4149,17 +4149,6 @@ class _OverrideBinds(Grouping[_T]): return ck -class _OverRange(Enum): - RANGE_UNBOUNDED = 0 - RANGE_CURRENT = 1 - - -RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED -RANGE_CURRENT = _OverRange.RANGE_CURRENT - -_IntOrRange = Union[int, _OverRange] - - class Over(ColumnElement[_T]): """Represent an OVER clause. 
@@ -4176,8 +4165,8 @@ class Over(ColumnElement[_T]): ("element", InternalTraversal.dp_clauseelement), ("order_by", InternalTraversal.dp_clauseelement), ("partition_by", InternalTraversal.dp_clauseelement), - ("range_", InternalTraversal.dp_plain_obj), - ("rows", InternalTraversal.dp_plain_obj), + ("range_", InternalTraversal.dp_clauseelement), + ("rows", InternalTraversal.dp_clauseelement), ] order_by: Optional[ClauseList] = None @@ -4187,8 +4176,8 @@ class Over(ColumnElement[_T]): """The underlying expression object to which this :class:`.Over` object refers.""" - range_: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] - rows: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] + range_: Optional[_FrameClause] + rows: Optional[_FrameClause] def __init__( self, @@ -4210,7 +4199,7 @@ class Over(ColumnElement[_T]): ) if range_: - self.range_ = self._interpret_range(range_) + self.range_ = _FrameClause(range_) if rows: raise exc.ArgumentError( "'range_' and 'rows' are mutually exclusive" @@ -4218,81 +4207,112 @@ class Over(ColumnElement[_T]): else: self.rows = None elif rows: - self.rows = self._interpret_range(rows) + self.rows = _FrameClause(rows) self.range_ = None else: self.rows = self.range_ = None - def __reduce__(self): - return self.__class__, ( - self.element, - self.partition_by, - self.order_by, - self.range_, - self.rows, + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + return self.element.type + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return list( + itertools.chain( + *[ + c._from_objects + for c in (self.element, self.partition_by, self.order_by) + if c is not None + ] + ) ) - def _interpret_range( - self, - range_: typing_Tuple[Optional[_IntOrRange], Optional[_IntOrRange]], - ) -> typing_Tuple[_IntOrRange, _IntOrRange]: - if not isinstance(range_, tuple) or len(range_) != 2: - raise exc.ArgumentError("2-tuple expected for range/rows") - r0, r1 = range_ +class 
_FrameClauseType(Enum): + RANGE_UNBOUNDED = 0 + RANGE_CURRENT = 1 + RANGE_PRECEDING = 2 + RANGE_FOLLOWING = 3 + + +class _FrameClause(ClauseElement): + """indicate the 'rows' or 'range' field of a window function, e.g. using + :class:`.Over`. - lower: _IntOrRange - upper: _IntOrRange + .. versionadded:: 2.1 + + """ + + __visit_name__ = "frame_clause" + + _traverse_internals: _TraverseInternalsType = [ + ("lower_integer_bind", InternalTraversal.dp_clauseelement), + ("upper_integer_bind", InternalTraversal.dp_clauseelement), + ("lower_type", InternalTraversal.dp_plain_obj), + ("upper_type", InternalTraversal.dp_plain_obj), + ] + + def __init__( + self, + range_: typing_Tuple[Optional[int], Optional[int]], + ): + try: + r0, r1 = range_ + except (ValueError, TypeError) as ve: + raise exc.ArgumentError("2-tuple expected for range/rows") from ve if r0 is None: - lower = RANGE_UNBOUNDED - elif isinstance(r0, _OverRange): - lower = r0 + self.lower_type = _FrameClauseType.RANGE_UNBOUNDED + self.lower_integer_bind = None else: try: - lower = int(r0) + lower_integer = int(r0) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" ) from err else: - if lower == 0: - lower = RANGE_CURRENT + if lower_integer == 0: + self.lower_type = _FrameClauseType.RANGE_CURRENT + self.lower_integer_bind = None + elif lower_integer < 0: + self.lower_type = _FrameClauseType.RANGE_PRECEDING + self.lower_integer_bind = literal( + abs(lower_integer), type_api.INTEGERTYPE + ) + else: + self.lower_type = _FrameClauseType.RANGE_FOLLOWING + self.lower_integer_bind = literal( + lower_integer, type_api.INTEGERTYPE + ) if r1 is None: - upper = RANGE_UNBOUNDED - elif isinstance(r1, _OverRange): - upper = r1 + self.upper_type = _FrameClauseType.RANGE_UNBOUNDED + self.upper_integer_bind = None else: try: - upper = int(r1) + upper_integer = int(r1) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" ) from err else: - if upper 
== 0: - upper = RANGE_CURRENT - - return lower, upper - - if not TYPE_CHECKING: - - @util.memoized_property - def type(self) -> TypeEngine[_T]: # noqa: A001 - return self.element.type - - @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - return list( - itertools.chain( - *[ - c._from_objects - for c in (self.element, self.partition_by, self.order_by) - if c is not None - ] - ) - ) + if upper_integer == 0: + self.upper_type = _FrameClauseType.RANGE_CURRENT + self.upper_integer_bind = None + elif upper_integer < 0: + self.upper_type = _FrameClauseType.RANGE_PRECEDING + self.upper_integer_bind = literal( + abs(upper_integer), type_api.INTEGERTYPE + ) + else: + self.upper_type = _FrameClauseType.RANGE_FOLLOWING + self.upper_integer_bind = literal( + upper_integer, type_api.INTEGERTYPE + ) class WithinGroup(ColumnElement[_T]): diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 9f2a08d151..882ca45967 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -1922,18 +1922,32 @@ class WindowFunctionTest(fixtures.TablesTest): eq_(rows, [(95,) for i in range(19)]) - def test_window_rows_between(self, connection): + def test_window_rows_between_w_caching(self, connection): some_table = self.tables.some_table - # note the rows are part of the cache key right now, not handled - # as binds. 
this is issue #11515 - rows = connection.execute( - select( - func.max(some_table.c.col2).over( - order_by=[some_table.c.col1], - rows=(-5, 0), - ) - ) - ).all() - - eq_(rows, [(i,) for i in range(5, 250, 5)]) + # this tests that dialects such as SQL Server which require literal + # rendering of ROWS BETWEEN and RANGE BETWEEN numerical values make + # use of literal_execute, for post-cache rendering of integer values, + # and not literal_binds which would include the integer values in the + # cached string (caching overall fixed in #11515) + for i in range(3): + for rows, expected in [ + ( + (5, 20), + list(range(105, 245, 5)) + ([245] * 16) + [None] * 5, + ), + ( + (20, 30), + list(range(155, 245, 5)) + ([245] * 11) + [None] * 20, + ), + ]: + result_rows = connection.execute( + select( + func.max(some_table.c.col2).over( + order_by=[some_table.c.col1], + rows=rows, + ) + ) + ).all() + + eq_(result_rows, [(i,) for i in expected]) diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index c1f6e7f113..d8947ab67b 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -1124,6 +1124,32 @@ class CoreFixtures: dont_compare_values_fixtures.append(_lambda_fixtures) + def _numeric_agnostic_window_functions(): + return ( + func.row_number().over( + order_by=table_a.c.a, + range_=(random.randint(50, 60), random.randint(60, 70)), + ), + func.row_number().over( + order_by=table_a.c.a, + range_=(random.randint(-40, -20), random.randint(60, 70)), + ), + func.row_number().over( + order_by=table_a.c.a, + rows=(random.randint(-40, -20), random.randint(60, 70)), + ), + func.row_number().over( + order_by=table_a.c.a, + range_=(None, random.randint(60, 70)), + ), + func.row_number().over( + order_by=table_a.c.a, + range_=(random.randint(50, 60), None), + ), + ) + + dont_compare_values_fixtures.append(_numeric_agnostic_window_functions) + # like fixture but returns at least two objects that compare equally equal_fixtures = [ lambda: ( -- 2.47.2