]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix Over serialization
authorFederico Caselli <cfederico87@gmail.com>
Wed, 29 May 2024 20:03:17 +0000 (22:03 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 29 May 2024 20:03:17 +0000 (22:03 +0200)
Fixed issue when serializing an :func:`_sql.over` clause with
unbounded range or rows.

Fixes: #11422
Change-Id: I52a9f72205fd9c7ef5620596c83551e73d5cee5b

doc/build/changelog/unreleased_20/11422.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/ext/test_serializer.py

diff --git a/doc/build/changelog/unreleased_20/11422.rst b/doc/build/changelog/unreleased_20/11422.rst
new file mode 100644 (file)
index 0000000..bde7879
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 11422
+
+    Fixed issue when serializing an :func:`_sql.over` clause with
+    unbounded range or rows.
index 080011eb7d0b8ca1e16f20ba23cc6449e0e896e8..1fd2d992338f7199beeef5dbf2bb808267348b12 100644 (file)
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 from decimal import Decimal
-from enum import IntEnum
+from enum import Enum
 import itertools
 import operator
 import re
@@ -4149,7 +4149,7 @@ class _OverrideBinds(Grouping[_T]):
         return ck
 
 
-class _OverRange(IntEnum):
+class _OverRange(Enum):
     RANGE_UNBOUNDED = 0
     RANGE_CURRENT = 1
 
@@ -4157,6 +4157,8 @@ class _OverRange(IntEnum):
 RANGE_UNBOUNDED = _OverRange.RANGE_UNBOUNDED
 RANGE_CURRENT = _OverRange.RANGE_CURRENT
 
+_IntOrRange = Union[int, _OverRange]
+
 
 class Over(ColumnElement[_T]):
     """Represent an OVER clause.
@@ -4185,7 +4187,8 @@ class Over(ColumnElement[_T]):
     """The underlying expression object to which this :class:`.Over`
     object refers."""
 
-    range_: Optional[typing_Tuple[int, int]]
+    range_: Optional[typing_Tuple[_IntOrRange, _IntOrRange]]
+    rows: Optional[typing_Tuple[_IntOrRange, _IntOrRange]]
 
     def __init__(
         self,
@@ -4230,19 +4233,24 @@ class Over(ColumnElement[_T]):
         )
 
     def _interpret_range(
-        self, range_: typing_Tuple[Optional[int], Optional[int]]
-    ) -> typing_Tuple[int, int]:
+        self,
+        range_: typing_Tuple[Optional[_IntOrRange], Optional[_IntOrRange]],
+    ) -> typing_Tuple[_IntOrRange, _IntOrRange]:
         if not isinstance(range_, tuple) or len(range_) != 2:
             raise exc.ArgumentError("2-tuple expected for range/rows")
 
-        lower: int
-        upper: int
+        r0, r1 = range_
+
+        lower: _IntOrRange
+        upper: _IntOrRange
 
-        if range_[0] is None:
+        if r0 is None:
             lower = RANGE_UNBOUNDED
+        elif isinstance(r0, _OverRange):
+            lower = r0
         else:
             try:
-                lower = int(range_[0])
+                lower = int(r0)
             except ValueError as err:
                 raise exc.ArgumentError(
                     "Integer or None expected for range value"
@@ -4251,11 +4259,13 @@ class Over(ColumnElement[_T]):
                 if lower == 0:
                     lower = RANGE_CURRENT
 
-        if range_[1] is None:
+        if r1 is None:
             upper = RANGE_UNBOUNDED
+        elif isinstance(r1, _OverRange):
+            upper = r1
         else:
             try:
-                upper = int(range_[1])
+                upper = int(r1)
             except ValueError as err:
                 raise exc.ArgumentError(
                     "Integer or None expected for range value"
index a52c59e2d34d23c3b88eedd75f48f77452bd5019..40544f3ba03632c2fb2a5c09a5f7c0e8697690f0 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import scoped_session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import combinations
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.entities import ComparableEntity
@@ -279,6 +280,34 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest):
             dialect="default",
         )
 
+    @combinations(
+        (
+            lambda: func.max(users.c.name).over(range_=(None, 0)),
+            "max(users.name) OVER (RANGE BETWEEN UNBOUNDED "
+            "PRECEDING AND CURRENT ROW)",
+        ),
+        (
+            lambda: func.max(users.c.name).over(range_=(0, None)),
+            "max(users.name) OVER (RANGE BETWEEN CURRENT "
+            "ROW AND UNBOUNDED FOLLOWING)",
+        ),
+        (
+            lambda: func.max(users.c.name).over(rows=(None, 0)),
+            "max(users.name) OVER (ROWS BETWEEN UNBOUNDED "
+            "PRECEDING AND CURRENT ROW)",
+        ),
+        (
+            lambda: func.max(users.c.name).over(rows=(0, None)),
+            "max(users.name) OVER (ROWS BETWEEN CURRENT "
+            "ROW AND UNBOUNDED FOLLOWING)",
+        ),
+    )
+    def test_over(self, over_fn, sql):
+        o = over_fn()
+        self.assert_compile(o, sql)
+        ol = serializer.loads(serializer.dumps(o), users.metadata)
+        self.assert_compile(ol, sql)
+
 
 class ColumnPropertyWParamTest(
     AssertsCompiledSQL, fixtures.DeclarativeMappedTest
@@ -331,7 +360,3 @@ class ColumnPropertyWParamTest(
             "CAST(left(test.some_id, :left_2) AS INTEGER) = :param_1",
             checkparams={"left_1": 6, "left_2": 6, "param_1": 123456},
         )
-
-
-if __name__ == "__main__":
-    testing.main()