]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use a new ClauseElement for Over.range_ / Over.rows
author Mike Bayer <mike_mp@zzzcomputing.com>
Sun, 23 Jun 2024 19:40:31 +0000 (15:40 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 24 Jun 2024 19:21:18 +0000 (15:21 -0400)
Enhanced the caching structure of the :paramref:`.over.rows` and
:paramref:`.over.range_` so that different numerical values for the rows /
range fields are cached on the same cache key, to the extent that the
underlying SQL does not actually change (i.e. "unbounded", "current row",
negative/positive status will still change the cache key).  This prevents
the use of many different numerical range/rows values for a query that is
otherwise identical from filling up the SQL cache.

Note that the semi-private compiler method ``_format_frame_clause()``
is removed by this fix, replaced with a new method
``visit_frame_clause()``.  Third party dialects which may have referred
to this method will need to change the name and revise the approach to
rendering the correct SQL for that dialect.

This patch introduces a new ClauseElement called _FrameClause which
stores the integer range values within cache-compatible BindParameter
objects, kept separate from the "type" of each bound, which
can be unbounded, current, preceding, or following, as represented by
a _FrameClauseType enum.  The negative
sign is also stripped from the integer and represented within the
_FrameClauseType.  Tests from #11514 are adapted to include
a test for SQL Server's "literal_execute" flag taking effect so
that literal numeric values aren't stored in the cache.

Fixes: #11515
Change-Id: I8aad368ffef9f06cb5c3f8c4e971fadef029ffd5

doc/build/changelog/unreleased_21/11515.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/testing/suite/test_select.py
test/sql/test_compare.py

diff --git a/doc/build/changelog/unreleased_21/11515.rst b/doc/build/changelog/unreleased_21/11515.rst
new file mode 100644 (file)
index 0000000..507ab3f
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 11515
+
+    Enhanced the caching structure of the :paramref:`.over.rows` and
+    :paramref:`.over.range_` so that different numerical values for the rows /
+    range fields are cached on the same cache key, to the extent that the
+    underlying SQL does not actually change (i.e. "unbounded", "current row",
+    negative/positive status will still change the cache key).  This prevents
+    the use of many different numerical range/rows values for a query that is
+    otherwise identical from filling up the SQL cache.
+
+    Note that the semi-private compiler method ``_format_frame_clause()``
+    is removed by this fix, replaced with a new method
+    ``visit_frame_clause()``.  Third party dialects which may have referred
+    to this method will need to change the name and revise the approach to
+    rendering the correct SQL for that dialect.
+
index ddee9a5a7398a9c809e36bdb90265deebf40925a..57b273e1a8e0b89f33a7b404805169a313f3f355 100644 (file)
@@ -1988,9 +1988,9 @@ class MSSQLCompiler(compiler.SQLCompiler):
         self.tablealiases = {}
         super().__init__(*args, **kwargs)
 
-    def _format_frame_clause(self, range_, **kw):
+    def visit_frame_clause(self, frameclause, **kw):
         kw["literal_execute"] = True
-        return super()._format_frame_clause(range_, **kw)
+        return super().visit_frame_clause(frameclause, **kw)
 
     def _with_legacy_schema_aliasing(fn):
         def decorate(self, *arg, **kw):
index 88e14645bbc62521354d3af03042b70e5779f6cd..18baf0f8e7fa551aa5a85fa016828babac184ec3 100644 (file)
@@ -2836,58 +2836,44 @@ class SQLCompiler(Compiled):
             match.group(2) if match else "",
         )
 
-    def _format_frame_clause(self, range_, **kw):
-        return "%s AND %s" % (
-            (
-                "UNBOUNDED PRECEDING"
-                if range_[0] is elements.RANGE_UNBOUNDED
-                else (
-                    "CURRENT ROW"
-                    if range_[0] is elements.RANGE_CURRENT
-                    else (
-                        "%s PRECEDING"
-                        % (
-                            self.process(
-                                elements.literal(abs(range_[0])), **kw
-                            ),
-                        )
-                        if range_[0] < 0
-                        else "%s FOLLOWING"
-                        % (self.process(elements.literal(range_[0]), **kw),)
-                    )
-                )
-            ),
-            (
-                "UNBOUNDED FOLLOWING"
-                if range_[1] is elements.RANGE_UNBOUNDED
-                else (
-                    "CURRENT ROW"
-                    if range_[1] is elements.RANGE_CURRENT
-                    else (
-                        "%s PRECEDING"
-                        % (
-                            self.process(
-                                elements.literal(abs(range_[1])), **kw
-                            ),
-                        )
-                        if range_[1] < 0
-                        else "%s FOLLOWING"
-                        % (self.process(elements.literal(range_[1]), **kw),)
-                    )
-                )
-            ),
-        )
+    def visit_frame_clause(self, frameclause, **kw):
+
+        if frameclause.lower_type is elements._FrameClauseType.RANGE_UNBOUNDED:
+            left = "UNBOUNDED PRECEDING"
+        elif frameclause.lower_type is elements._FrameClauseType.RANGE_CURRENT:
+            left = "CURRENT ROW"
+        else:
+            val = self.process(frameclause.lower_integer_bind, **kw)
+            if (
+                frameclause.lower_type
+                is elements._FrameClauseType.RANGE_PRECEDING
+            ):
+                left = f"{val} PRECEDING"
+            else:
+                left = f"{val} FOLLOWING"
+
+        if frameclause.upper_type is elements._FrameClauseType.RANGE_UNBOUNDED:
+            right = "UNBOUNDED FOLLOWING"
+        elif frameclause.upper_type is elements._FrameClauseType.RANGE_CURRENT:
+            right = "CURRENT ROW"
+        else:
+            val = self.process(frameclause.upper_integer_bind, **kw)
+            if (
+                frameclause.upper_type
+                is elements._FrameClauseType.RANGE_PRECEDING
+            ):
+                right = f"{val} PRECEDING"
+            else:
+                right = f"{val} FOLLOWING"
+
+        return f"{left} AND {right}"
 
     def visit_over(self, over, **kwargs):
         text = over.element._compiler_dispatch(self, **kwargs)
-        if over.range_:
-            range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
-                over.range_, **kwargs
-            )
-        elif over.rows:
-            range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
-                over.rows, **kwargs
-            )
+        if over.range_ is not None:
+            range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}"
+        elif over.rows is not None:
+            range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}"
         else:
             range_ = None
 
index 80e98c1e19ce2b435b57d3604207ee8134b35047..a4841e07f3dc14c977a39ac58ffdbacb5c5862ef 100644 (file)
@@ -4149,17 +4149,6 @@ class _OverrideBinds(Grouping[_T]):
         return ck
 
 
-class _OverRange(Enum):
-    RANGE_UNBOUNDED = 0
-    RANGE_CURRENT = 1
-
-
-RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED
-RANGE_CURRENT = _OverRange.RANGE_CURRENT
-
-_IntOrRange = Union[int, _OverRange]
-
-
 class Over(ColumnElement[_T]):
     """Represent an OVER clause.
 
@@ -4176,8 +4165,8 @@ class Over(ColumnElement[_T]):
         ("element", InternalTraversal.dp_clauseelement),
         ("order_by", InternalTraversal.dp_clauseelement),
         ("partition_by", InternalTraversal.dp_clauseelement),
-        ("range_", InternalTraversal.dp_plain_obj),
-        ("rows", InternalTraversal.dp_plain_obj),
+        ("range_", InternalTraversal.dp_clauseelement),
+        ("rows", InternalTraversal.dp_clauseelement),
     ]
 
     order_by: Optional[ClauseList] = None
@@ -4187,8 +4176,8 @@ class Over(ColumnElement[_T]):
     """The underlying expression object to which this :class:`.Over`
     object refers."""
 
-    range_: Optional[typing_Tuple[_IntOrRange, _IntOrRange]]
-    rows: Optional[typing_Tuple[_IntOrRange, _IntOrRange]]
+    range_: Optional[_FrameClause]
+    rows: Optional[_FrameClause]
 
     def __init__(
         self,
@@ -4210,7 +4199,7 @@ class Over(ColumnElement[_T]):
             )
 
         if range_:
-            self.range_ = self._interpret_range(range_)
+            self.range_ = _FrameClause(range_)
             if rows:
                 raise exc.ArgumentError(
                     "'range_' and 'rows' are mutually exclusive"
@@ -4218,81 +4207,112 @@ class Over(ColumnElement[_T]):
             else:
                 self.rows = None
         elif rows:
-            self.rows = self._interpret_range(rows)
+            self.rows = _FrameClause(rows)
             self.range_ = None
         else:
             self.rows = self.range_ = None
 
-    def __reduce__(self):
-        return self.__class__, (
-            self.element,
-            self.partition_by,
-            self.order_by,
-            self.range_,
-            self.rows,
+    if not TYPE_CHECKING:
+
+        @util.memoized_property
+        def type(self) -> TypeEngine[_T]:  # noqa: A001
+            return self.element.type
+
+    @util.ro_non_memoized_property
+    def _from_objects(self) -> List[FromClause]:
+        return list(
+            itertools.chain(
+                *[
+                    c._from_objects
+                    for c in (self.element, self.partition_by, self.order_by)
+                    if c is not None
+                ]
+            )
         )
 
-    def _interpret_range(
-        self,
-        range_: typing_Tuple[Optional[_IntOrRange], Optional[_IntOrRange]],
-    ) -> typing_Tuple[_IntOrRange, _IntOrRange]:
-        if not isinstance(range_, tuple) or len(range_) != 2:
-            raise exc.ArgumentError("2-tuple expected for range/rows")
 
-        r0, r1 = range_
+class _FrameClauseType(Enum):
+    RANGE_UNBOUNDED = 0
+    RANGE_CURRENT = 1
+    RANGE_PRECEDING = 2
+    RANGE_FOLLOWING = 3
+
+
+class _FrameClause(ClauseElement):
+    """indicate the 'rows' or 'range' field of a window function, e.g. using
+    :class:`.Over`.
 
-        lower: _IntOrRange
-        upper: _IntOrRange
+    .. versionadded:: 2.1
+
+    """
+
+    __visit_name__ = "frame_clause"
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("lower_integer_bind", InternalTraversal.dp_clauseelement),
+        ("upper_integer_bind", InternalTraversal.dp_clauseelement),
+        ("lower_type", InternalTraversal.dp_plain_obj),
+        ("upper_type", InternalTraversal.dp_plain_obj),
+    ]
+
+    def __init__(
+        self,
+        range_: typing_Tuple[Optional[int], Optional[int]],
+    ):
+        try:
+            r0, r1 = range_
+        except (ValueError, TypeError) as ve:
+            raise exc.ArgumentError("2-tuple expected for range/rows") from ve
 
         if r0 is None:
-            lower = RANGE_UNBOUNDED
-        elif isinstance(r0, _OverRange):
-            lower = r0
+            self.lower_type = _FrameClauseType.RANGE_UNBOUNDED
+            self.lower_integer_bind = None
         else:
             try:
-                lower = int(r0)
+                lower_integer = int(r0)
             except ValueError as err:
                 raise exc.ArgumentError(
                     "Integer or None expected for range value"
                 ) from err
             else:
-                if lower == 0:
-                    lower = RANGE_CURRENT
+                if lower_integer == 0:
+                    self.lower_type = _FrameClauseType.RANGE_CURRENT
+                    self.lower_integer_bind = None
+                elif lower_integer < 0:
+                    self.lower_type = _FrameClauseType.RANGE_PRECEDING
+                    self.lower_integer_bind = literal(
+                        abs(lower_integer), type_api.INTEGERTYPE
+                    )
+                else:
+                    self.lower_type = _FrameClauseType.RANGE_FOLLOWING
+                    self.lower_integer_bind = literal(
+                        lower_integer, type_api.INTEGERTYPE
+                    )
 
         if r1 is None:
-            upper = RANGE_UNBOUNDED
-        elif isinstance(r1, _OverRange):
-            upper = r1
+            self.upper_type = _FrameClauseType.RANGE_UNBOUNDED
+            self.upper_integer_bind = None
         else:
             try:
-                upper = int(r1)
+                upper_integer = int(r1)
             except ValueError as err:
                 raise exc.ArgumentError(
                     "Integer or None expected for range value"
                 ) from err
             else:
-                if upper == 0:
-                    upper = RANGE_CURRENT
-
-        return lower, upper
-
-    if not TYPE_CHECKING:
-
-        @util.memoized_property
-        def type(self) -> TypeEngine[_T]:  # noqa: A001
-            return self.element.type
-
-    @util.ro_non_memoized_property
-    def _from_objects(self) -> List[FromClause]:
-        return list(
-            itertools.chain(
-                *[
-                    c._from_objects
-                    for c in (self.element, self.partition_by, self.order_by)
-                    if c is not None
-                ]
-            )
-        )
+                if upper_integer == 0:
+                    self.upper_type = _FrameClauseType.RANGE_CURRENT
+                    self.upper_integer_bind = None
+                elif upper_integer < 0:
+                    self.upper_type = _FrameClauseType.RANGE_PRECEDING
+                    self.upper_integer_bind = literal(
+                        abs(upper_integer), type_api.INTEGERTYPE
+                    )
+                else:
+                    self.upper_type = _FrameClauseType.RANGE_FOLLOWING
+                    self.upper_integer_bind = literal(
+                        upper_integer, type_api.INTEGERTYPE
+                    )
 
 
 class WithinGroup(ColumnElement[_T]):
index 9f2a08d151a6bf938d2fb901561604e9f9515636..882ca4596786f6c7e592441ede87cdb6d178b351 100644 (file)
@@ -1922,18 +1922,32 @@ class WindowFunctionTest(fixtures.TablesTest):
 
         eq_(rows, [(95,) for i in range(19)])
 
-    def test_window_rows_between(self, connection):
+    def test_window_rows_between_w_caching(self, connection):
         some_table = self.tables.some_table
 
-        # note the rows are part of the cache key right now, not handled
-        # as binds.  this is issue #11515
-        rows = connection.execute(
-            select(
-                func.max(some_table.c.col2).over(
-                    order_by=[some_table.c.col1],
-                    rows=(-5, 0),
-                )
-            )
-        ).all()
-
-        eq_(rows, [(i,) for i in range(5, 250, 5)])
+        # this tests that dialects such as SQL Server which require literal
+        # rendering of ROWS BETWEEN and RANGE BETWEEN numerical values make
+        # use of literal_execute, for post-cache rendering of integer values,
+        # and not literal_binds which would include the integer values in the
+        # cached string (caching overall fixed in #11515)
+        for i in range(3):
+            for rows, expected in [
+                (
+                    (5, 20),
+                    list(range(105, 245, 5)) + ([245] * 16) + [None] * 5,
+                ),
+                (
+                    (20, 30),
+                    list(range(155, 245, 5)) + ([245] * 11) + [None] * 20,
+                ),
+            ]:
+                result_rows = connection.execute(
+                    select(
+                        func.max(some_table.c.col2).over(
+                            order_by=[some_table.c.col1],
+                            rows=rows,
+                        )
+                    )
+                ).all()
+
+                eq_(result_rows, [(i,) for i in expected])
index c1f6e7f11368800ce62130ac4af81c0eb0e9af84..d8947ab67b73fb6d9fbb8da9eb7d5cdf039ccab7 100644 (file)
@@ -1124,6 +1124,32 @@ class CoreFixtures:
 
     dont_compare_values_fixtures.append(_lambda_fixtures)
 
+    def _numeric_agnostic_window_functions():
+        return (
+            func.row_number().over(
+                order_by=table_a.c.a,
+                range_=(random.randint(50, 60), random.randint(60, 70)),
+            ),
+            func.row_number().over(
+                order_by=table_a.c.a,
+                range_=(random.randint(-40, -20), random.randint(60, 70)),
+            ),
+            func.row_number().over(
+                order_by=table_a.c.a,
+                rows=(random.randint(-40, -20), random.randint(60, 70)),
+            ),
+            func.row_number().over(
+                order_by=table_a.c.a,
+                range_=(None, random.randint(60, 70)),
+            ),
+            func.row_number().over(
+                order_by=table_a.c.a,
+                range_=(random.randint(50, 60), None),
+            ),
+        )
+
+    dont_compare_values_fixtures.append(_numeric_agnostic_window_functions)
+
     # like fixture but returns at least two objects that compare equally
     equal_fixtures = [
         lambda: (