From 9f4f84ffdc1be487930b00d0b190bd492d302ca1 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Wed, 29 May 2024 22:03:17 +0200 Subject: [PATCH] Fix Over serialization Fixed issue when serializing an :func:`_sql.over` clause with unbounded range or rows. Fixes: #11422 Change-Id: I52a9f72205fd9c7ef5620596c83551e73d5cee5b --- doc/build/changelog/unreleased_20/11422.rst | 6 ++++ lib/sqlalchemy/sql/elements.py | 32 +++++++++++++------- test/ext/test_serializer.py | 33 ++++++++++++++++++--- 3 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11422.rst diff --git a/doc/build/changelog/unreleased_20/11422.rst b/doc/build/changelog/unreleased_20/11422.rst new file mode 100644 index 0000000000..bde7879338 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11422.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, sql + :tickets: 11422 + + Fixed issue when serializing an :func:`_sql.over` clause with + unbounded range or rows. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 080011eb7d..1fd2d99233 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -14,7 +14,7 @@ from __future__ import annotations from decimal import Decimal -from enum import IntEnum +from enum import Enum import itertools import operator import re @@ -4149,7 +4149,7 @@ class _OverrideBinds(Grouping[_T]): return ck -class _OverRange(IntEnum): +class _OverRange(Enum): RANGE_UNBOUNDED = 0 RANGE_CURRENT = 1 @@ -4157,6 +4157,8 @@ class _OverRange(IntEnum): RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED RANGE_CURRENT = _OverRange.RANGE_CURRENT +_IntOrRange = Union[int, _OverRange] + class Over(ColumnElement[_T]): """Represent an OVER clause. @@ -4185,7 +4187,8 @@ class Over(ColumnElement[_T]): """The underlying expression object to which this :class:`.Over` object refers.""" - range_: Optional[typing_Tuple[int, int]] + range_: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] + rows: Optional[typing_Tuple[_IntOrRange, _IntOrRange]] def __init__( self, @@ -4230,19 +4233,24 @@ class Over(ColumnElement[_T]): ) def _interpret_range( - self, range_: typing_Tuple[Optional[int], Optional[int]] - ) -> typing_Tuple[int, int]: + self, + range_: typing_Tuple[Optional[_IntOrRange], Optional[_IntOrRange]], + ) -> typing_Tuple[_IntOrRange, _IntOrRange]: if not isinstance(range_, tuple) or len(range_) != 2: raise exc.ArgumentError("2-tuple expected for range/rows") - lower: int - upper: int + r0, r1 = range_ + + lower: _IntOrRange + upper: _IntOrRange - if range_[0] is None: + if r0 is None: lower = RANGE_UNBOUNDED + elif isinstance(r0, _OverRange): + lower = r0 else: try: - lower = int(range_[0]) + lower = int(r0) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" @@ -4251,11 +4259,13 @@ class Over(ColumnElement[_T]): if lower == 0: lower = RANGE_CURRENT - if range_[1] is None: + if r1 is None: upper = RANGE_UNBOUNDED + elif isinstance(r1, _OverRange): + upper = r1 else: try: - upper = int(range_[1]) + upper = int(r1) except ValueError as err: raise exc.ArgumentError( "Integer or None expected for range value" diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index a52c59e2d3..40544f3ba0 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import scoped_session from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing.entities import ComparableEntity @@ -279,6 +280,34 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): dialect="default", ) + @combinations( + ( + lambda: func.max(users.c.name).over(range_=(None, 0)), + "max(users.name) OVER (RANGE BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(range_=(0, None)), + "max(users.name) OVER (RANGE BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), + ( + lambda: func.max(users.c.name).over(rows=(None, 0)), + "max(users.name) OVER (ROWS BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(rows=(0, None)), + "max(users.name) OVER (ROWS BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), + ) + def test_over(self, over_fn, sql): + o = over_fn() + self.assert_compile(o, sql) + ol = serializer.loads(serializer.dumps(o), users.metadata) + self.assert_compile(ol, sql) + class ColumnPropertyWParamTest( AssertsCompiledSQL, fixtures.DeclarativeMappedTest @@ -331,7 +360,3 @@ class ColumnPropertyWParamTest( "CAST(left(test.some_id, :left_2) AS INTEGER) = :param_1", checkparams={"left_1": 6, "left_2": 6, "param_1": 123456}, ) - - -if __name__ == "__main__": - testing.main() -- 2.47.2