From: Kaan Date: Tue, 18 Mar 2025 12:46:52 +0000 (+0000) Subject: implement GROUPS frame spec for window functions X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8d57d0061fd67f5277b7d52a16814d7ffc73ef73;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement GROUPS frame spec for window functions --- diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 799c87c82b..f06aeefbe6 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1500,6 +1500,7 @@ def over( order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: r"""Produce an :class:`.Over` object against a function. @@ -1562,10 +1563,12 @@ def over( :param range\_: optional range clause for the window. This is a tuple value which can contain integer values or ``None``, and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause. - :param rows: optional rows clause for the window. This is a tuple value which can contain integer values or None, and will render a ROWS BETWEEN PRECEDING / FOLLOWING clause. + :param groups: optional groups clause for the window. This is a + tuple value which can contain integer values or ``None``, + and will render a GROUPS BETWEEN PRECEDING / FOLLOWING clause. This function is also available from the :data:`~.expression.func` construct itself via the :meth:`.FunctionElement.over` method. @@ -1579,7 +1582,7 @@ def over( :func:`_expression.within_group` """ # noqa: E501 - return Over(element, partition_by, order_by, range_, rows) + return Over(element, partition_by, order_by, range_, rows, groups) @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 20073a3afa..cb25f880cb 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2880,6 +2880,8 @@ class SQLCompiler(Compiled): range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}" elif over.rows is not None: range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}" + elif over.groups is not None: + range_ = f"GROUPS BETWEEN {self.process(over.groups, **kwargs)}" else: range_ = None diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index c9aac427db..4d229dbb4a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4212,6 +4212,7 @@ class Over(ColumnElement[_T]): ("partition_by", InternalTraversal.dp_clauseelement), ("range_", InternalTraversal.dp_clauseelement), ("rows", InternalTraversal.dp_clauseelement), + ("groups", InternalTraversal.dp_clauseelement), ] order_by: Optional[ClauseList] = None @@ -4223,6 +4224,7 @@ class Over(ColumnElement[_T]): range_: Optional[_FrameClause] rows: Optional[_FrameClause] + groups: Optional[_FrameClause] def __init__( self, @@ -4231,6 +4233,7 @@ class Over(ColumnElement[_T]): order_by: Optional[_ByArgument] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ): self.element = element if order_by is not None: @@ -4243,19 +4246,37 @@ class Over(ColumnElement[_T]): _literal_as_text_role=roles.ByOfRole, ) + self.range_, self.rows, self.groups = None, None, None if range_: self.range_ = _FrameClause(range_) if rows: raise exc.ArgumentError( "'range_' and 'rows' are mutually exclusive" ) - else: - self.rows = None - elif rows: + if groups: + raise exc.ArgumentError( + "'range_' and 'groups' are mutually exclusive" + ) + if rows: self.rows = _FrameClause(rows) - self.range_ = None - else: - self.rows = self.range_ = None + if range_: + raise exc.ArgumentError( + "'rows' and 'range_' are mutually exclusive" + ) + if groups: + raise exc.ArgumentError( + "'rows' and 'groups' are mutually exclusive" + ) + if groups: + self.groups = _FrameClause(groups) + if range_: + raise exc.ArgumentError( + "'groups' and 'range_' are mutually exclusive" + ) + if rows: + raise exc.ArgumentError( + "'groups' and 'rows' are mutually exclusive" + ) if not TYPE_CHECKING: @@ -4409,6 +4430,7 @@ class WithinGroup(ColumnElement[_T]): order_by: Optional[_ByArgument] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4423,6 +4445,7 @@ class WithinGroup(ColumnElement[_T]): order_by=order_by, range_=range_, rows=rows, + groups=groups, ) @overload @@ -4540,6 +4563,7 @@ class FunctionFilter(Generative, ColumnElement[_T]): ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: """Produce an OVER clause against this filtered function. @@ -4565,6 +4589,7 @@ class FunctionFilter(Generative, ColumnElement[_T]): order_by=order_by, range_=range_, rows=rows, + groups=groups, ) def within_group( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 87a68cfd90..7148d28281 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -435,6 +435,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): order_by: Optional[_ByArgument] = None, rows: Optional[Tuple[Optional[int], Optional[int]]] = None, range_: Optional[Tuple[Optional[int], Optional[int]]] = None, + groups: Optional[Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: """Produce an OVER clause against this function. @@ -466,6 +467,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): order_by=order_by, rows=rows, range_=range_, + groups=groups, ) def within_group( diff --git a/test/ext/test_serializer.py b/test/ext/test_serializer.py index 40544f3ba0..fb92c752a6 100644 --- a/test/ext/test_serializer.py +++ b/test/ext/test_serializer.py @@ -301,6 +301,16 @@ class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest): "max(users.name) OVER (ROWS BETWEEN CURRENT " "ROW AND UNBOUNDED FOLLOWING)", ), + ( + lambda: func.max(users.c.name).over(groups=(None, 0)), + "max(users.name) OVER (GROUPS BETWEEN UNBOUNDED " + "PRECEDING AND CURRENT ROW)", + ), + ( + lambda: func.max(users.c.name).over(groups=(0, None)), + "max(users.name) OVER (GROUPS BETWEEN CURRENT " + "ROW AND UNBOUNDED FOLLOWING)", + ), ) def test_over(self, over_fn, sql): o = over_fn() diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 9d74a8d2f4..3dc73484f0 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -3208,6 +3208,41 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): checkparams={"param_1": 10, "param_2": 1}, ) + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(None, 0))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + "UNBOUNDED PRECEDING AND CURRENT ROW)" + " AS anon_1 FROM mytable", + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(-5, 10))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 PRECEDING AND :param_2 FOLLOWING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 5, "param_2": 10}, + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(1, 10))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 FOLLOWING AND :param_2 FOLLOWING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 1, "param_2": 10}, + ) + + self.assert_compile( + select(func.row_number().over(order_by=expr, groups=(-10, -1))), + "SELECT row_number() OVER " + "(ORDER BY mytable.myid GROUPS BETWEEN " + ":param_1 PRECEDING AND :param_2 PRECEDING)" + " AS anon_1 FROM mytable", + checkparams={"param_1": 10, "param_2": 1}, + ) + def test_over_invalid_framespecs(self): assert_raises_message( exc.ArgumentError, diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index 163df0a0d7..28cdb03a96 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -844,6 +844,34 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "AS anon_1 FROM mytable", ) + def test_funcfilter_windowing_groups(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(groups=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > :name_1) " + "OVER (PARTITION BY mytable.description GROUPS BETWEEN :param_1 " + "FOLLOWING AND :param_2 FOLLOWING) " + "AS anon_1 FROM mytable", + ) + + def test_funcfilter_windowing_groups_positional(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .over(groups=(1, 5), partition_by=["description"]) + ), + "SELECT rank() FILTER (WHERE mytable.name > ?) " + "OVER (PARTITION BY mytable.description GROUPS BETWEEN ? " + "FOLLOWING AND ? FOLLOWING) " + "AS anon_1 FROM mytable", + checkpositional=("foo", 1, 5), + dialect="default_qmark", + ) + def test_funcfilter_more_criteria(self): ff = func.rank().filter(table1.c.name > "foo") ff2 = ff.filter(table1.c.myid == 1)