]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement GROUPS frame spec for window functions
authorKaan <kaan191@gmail.com>
Tue, 18 Mar 2025 12:46:52 +0000 (12:46 +0000)
committerKaan <kaan191@gmail.com>
Tue, 18 Mar 2025 12:46:52 +0000 (12:46 +0000)
lib/sqlalchemy/sql/_elements_constructors.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
test/ext/test_serializer.py
test/sql/test_compiler.py
test/sql/test_functions.py

index 799c87c82ba631d7b2b6a6395b56b6658787c06d..f06aeefbe6f53147dd7b3ce8671d379fce5af67e 100644 (file)
@@ -1500,6 +1500,7 @@ def over(
     order_by: Optional[_ByArgument] = None,
     range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+    groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
 ) -> Over[_T]:
     r"""Produce an :class:`.Over` object against a function.
 
@@ -1562,10 +1563,12 @@ def over(
     :param range\_: optional range clause for the window.  This is a
      tuple value which can contain integer values or ``None``,
      and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause.
-
     :param rows: optional rows clause for the window.  This is a tuple
      value which can contain integer values or None, and will render
      a ROWS BETWEEN PRECEDING / FOLLOWING clause.
+    :param groups: optional groups clause for the window.  This is a
+     tuple value which can contain integer values or ``None``,
+     and will render a GROUPS BETWEEN PRECEDING / FOLLOWING clause.
 
     This function is also available from the :data:`~.expression.func`
     construct itself via the :meth:`.FunctionElement.over` method.
@@ -1579,7 +1582,7 @@ def over(
         :func:`_expression.within_group`
 
     """  # noqa: E501
-    return Over(element, partition_by, order_by, range_, rows)
+    return Over(element, partition_by, order_by, range_, rows, groups)
 
 
 @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
index 20073a3afaa0a75390fe0f5881ec8be45cf83ef1..cb25f880cb4549239a0146b5858106e28df56bd6 100644 (file)
@@ -2880,6 +2880,8 @@ class SQLCompiler(Compiled):
             range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}"
         elif over.rows is not None:
             range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}"
+        elif over.groups is not None:
+            range_ = f"GROUPS BETWEEN {self.process(over.groups, **kwargs)}"
         else:
             range_ = None
 
index c9aac427dbefbd4646694012b4e05d6400afa186..4d229dbb4af8c68d3bcbf6e6e52f302bad1fe868 100644 (file)
@@ -4212,6 +4212,7 @@ class Over(ColumnElement[_T]):
         ("partition_by", InternalTraversal.dp_clauseelement),
         ("range_", InternalTraversal.dp_clauseelement),
         ("rows", InternalTraversal.dp_clauseelement),
+        ("groups", InternalTraversal.dp_clauseelement),
     ]
 
     order_by: Optional[ClauseList] = None
@@ -4223,6 +4224,7 @@ class Over(ColumnElement[_T]):
 
     range_: Optional[_FrameClause]
     rows: Optional[_FrameClause]
+    groups: Optional[_FrameClause]
 
     def __init__(
         self,
@@ -4231,6 +4233,7 @@ class Over(ColumnElement[_T]):
         order_by: Optional[_ByArgument] = None,
         range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
         rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+        groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     ):
         self.element = element
         if order_by is not None:
@@ -4243,19 +4246,37 @@ class Over(ColumnElement[_T]):
                 _literal_as_text_role=roles.ByOfRole,
             )
 
+        self.range_, self.rows, self.groups = None, None, None
         if range_:
             self.range_ = _FrameClause(range_)
             if rows:
                 raise exc.ArgumentError(
                     "'range_' and 'rows' are mutually exclusive"
                 )
-            else:
-                self.rows = None
-        elif rows:
+            if groups:
+                raise exc.ArgumentError(
+                    "'range_' and 'groups' are mutually exclusive"
+                )
+        if rows:
             self.rows = _FrameClause(rows)
-            self.range_ = None
-        else:
-            self.rows = self.range_ = None
+            if range_:
+                raise exc.ArgumentError(
+                    "'rows' and 'range_' are mutually exclusive"
+                )
+            if groups:
+                raise exc.ArgumentError(
+                    "'rows' and 'groups' are mutually exclusive"
+                )
+        if groups:
+            self.groups = _FrameClause(groups)
+            if range_:
+                raise exc.ArgumentError(
+                    "'groups' and 'range_' are mutually exclusive"
+                )
+            if rows:
+                raise exc.ArgumentError(
+                    "'groups' and 'rows' are mutually exclusive"
+                )
 
     if not TYPE_CHECKING:
 
@@ -4409,6 +4430,7 @@ class WithinGroup(ColumnElement[_T]):
         order_by: Optional[_ByArgument] = None,
         rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
         range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+        groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     ) -> Over[_T]:
         """Produce an OVER clause against this :class:`.WithinGroup`
         construct.
@@ -4423,6 +4445,7 @@ class WithinGroup(ColumnElement[_T]):
             order_by=order_by,
             range_=range_,
             rows=rows,
+            groups=groups,
         )
 
     @overload
@@ -4540,6 +4563,7 @@ class FunctionFilter(Generative, ColumnElement[_T]):
         ] = None,
         range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
         rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+        groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     ) -> Over[_T]:
         """Produce an OVER clause against this filtered function.
 
@@ -4565,6 +4589,7 @@ class FunctionFilter(Generative, ColumnElement[_T]):
             order_by=order_by,
             range_=range_,
             rows=rows,
+            groups=groups,
         )
 
     def within_group(
index 87a68cfd90b8733cba6c14dd1a4e113c77ded5c8..7148d28281ff8bc87fb8bf97a1b060ce4a58b383 100644 (file)
@@ -435,6 +435,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         order_by: Optional[_ByArgument] = None,
         rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
         range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+        groups: Optional[Tuple[Optional[int], Optional[int]]] = None,
     ) -> Over[_T]:
         """Produce an OVER clause against this function.
 
@@ -466,6 +467,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
             order_by=order_by,
             rows=rows,
             range_=range_,
+            groups=groups,
         )
 
     def within_group(
index 40544f3ba03632c2fb2a5c09a5f7c0e8697690f0..fb92c752a67707314ebc2d2804c891b490d4e73c 100644 (file)
@@ -301,6 +301,16 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest):
             "max(users.name) OVER (ROWS BETWEEN CURRENT "
             "ROW AND UNBOUNDED FOLLOWING)",
         ),
+        (
+            lambda: func.max(users.c.name).over(groups=(None, 0)),
+            "max(users.name) OVER (GROUPS BETWEEN UNBOUNDED "
+            "PRECEDING AND CURRENT ROW)",
+        ),
+        (
+            lambda: func.max(users.c.name).over(groups=(0, None)),
+            "max(users.name) OVER (GROUPS BETWEEN CURRENT "
+            "ROW AND UNBOUNDED FOLLOWING)",
+        ),
     )
     def test_over(self, over_fn, sql):
         o = over_fn()
index 9d74a8d2f4c609d363ec95911715f8d1ef3492e6..3dc73484f078875dbc110d1dd280ad95645d257f 100644 (file)
@@ -3208,6 +3208,41 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={"param_1": 10, "param_2": 1},
         )
 
+        self.assert_compile(
+            select(func.row_number().over(order_by=expr, groups=(None, 0))),
+            "SELECT row_number() OVER "
+            "(ORDER BY mytable.myid GROUPS BETWEEN "
+            "UNBOUNDED PRECEDING AND CURRENT ROW)"
+            " AS anon_1 FROM mytable",
+        )
+
+        self.assert_compile(
+            select(func.row_number().over(order_by=expr, groups=(-5, 10))),
+            "SELECT row_number() OVER "
+            "(ORDER BY mytable.myid GROUPS BETWEEN "
+            ":param_1 PRECEDING AND :param_2 FOLLOWING)"
+            " AS anon_1 FROM mytable",
+            checkparams={"param_1": 5, "param_2": 10},
+        )
+
+        self.assert_compile(
+            select(func.row_number().over(order_by=expr, groups=(1, 10))),
+            "SELECT row_number() OVER "
+            "(ORDER BY mytable.myid GROUPS BETWEEN "
+            ":param_1 FOLLOWING AND :param_2 FOLLOWING)"
+            " AS anon_1 FROM mytable",
+            checkparams={"param_1": 1, "param_2": 10},
+        )
+
+        self.assert_compile(
+            select(func.row_number().over(order_by=expr, groups=(-10, -1))),
+            "SELECT row_number() OVER "
+            "(ORDER BY mytable.myid GROUPS BETWEEN "
+            ":param_1 PRECEDING AND :param_2 PRECEDING)"
+            " AS anon_1 FROM mytable",
+            checkparams={"param_1": 10, "param_2": 1},
+        )
+
     def test_over_invalid_framespecs(self):
         assert_raises_message(
             exc.ArgumentError,
index 163df0a0d71b9aca353a89bdc9967a0297047a07..28cdb03a9657136af7a004de4052617d0c816a63 100644 (file)
@@ -844,6 +844,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS anon_1 FROM mytable",
         )
 
+    def test_funcfilter_windowing_groups(self):
+        self.assert_compile(
+            select(
+                func.rank()
+                .filter(table1.c.name > "foo")
+                .over(groups=(1, 5), partition_by=["description"])
+            ),
+            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
+            "OVER (PARTITION BY mytable.description GROUPS BETWEEN :param_1 "
+            "FOLLOWING AND :param_2 FOLLOWING) "
+            "AS anon_1 FROM mytable",
+        )
+
+    def test_funcfilter_windowing_groups_positional(self):
+        self.assert_compile(
+            select(
+                func.rank()
+                .filter(table1.c.name > "foo")
+                .over(groups=(1, 5), partition_by=["description"])
+            ),
+            "SELECT rank() FILTER (WHERE mytable.name > ?) "
+            "OVER (PARTITION BY mytable.description GROUPS BETWEEN ? "
+            "FOLLOWING AND ? FOLLOWING) "
+            "AS anon_1 FROM mytable",
+            checkpositional=("foo", 1, 5),
+            dialect="default_qmark",
+        )
+
     def test_funcfilter_more_criteria(self):
         ff = func.rank().filter(table1.c.name > "foo")
         ff2 = ff.filter(table1.c.myid == 1)