From: Federico Caselli Date: Wed, 29 May 2024 20:18:50 +0000 (+0200) Subject: Add missing function element methods X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57bba096599ff10be008283261054e46c9d08d0b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add missing function element methods Added missing methods :meth:`_sql.FunctionFilter.within_group` and :meth:`_sql.WithinGroup.filter` Fixes: #11423 Change-Id: I4bafd9e3cab5883b28b2b997269df239739a2212 --- diff --git a/doc/build/changelog/unreleased_20/11423.rst b/doc/build/changelog/unreleased_20/11423.rst new file mode 100644 index 0000000000..ed6f988460 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11423.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, sql + :tickets: 11423 + + Added missing methods :meth:`_sql.FunctionFilter.within_group` + and :meth:`_sql.WithinGroup.filter` diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 080011eb7d..cb43d11a1b 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -4313,7 +4313,7 @@ class WithinGroup(ColumnElement[_T]): def __init__( self, - element: FunctionElement[_T], + element: Union[FunctionElement[_T], FunctionFilter[_T]], *order_by: _ColumnExpressionArgument[Any], ): self.element = element @@ -4327,7 +4327,14 @@ class WithinGroup(ColumnElement[_T]): tuple(self.order_by) if self.order_by is not None else () ) - def over(self, partition_by=None, order_by=None, range_=None, rows=None): + def over( + self, + *, + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + ) -> Over[_T]: """Produce an OVER clause against this :class:`.WithinGroup` construct. @@ -4343,6 +4350,24 @@ class WithinGroup(ColumnElement[_T]): rows=rows, ) + @overload + def filter(self) -> Self: ... + + @overload + def filter( + self, + __criterion0: _ColumnExpressionArgument[bool], + *criterion: _ColumnExpressionArgument[bool], + ) -> FunctionFilter[_T]: ... + + def filter( + self, *criterion: _ColumnExpressionArgument[bool] + ) -> Union[Self, FunctionFilter[_T]]: + """Produce a FILTER clause against this function.""" + if not criterion: + return self + return FunctionFilter(self, *criterion) + if not TYPE_CHECKING: @util.memoized_property @@ -4395,7 +4420,7 @@ class FunctionFilter(ColumnElement[_T]): def __init__( self, - func: FunctionElement[_T], + func: Union[FunctionElement[_T], WithinGroup[_T]], *criterion: _ColumnExpressionArgument[bool], ): self.func = func @@ -4465,6 +4490,19 @@ class FunctionFilter(ColumnElement[_T]): rows=rows, ) + def within_group( + self, *order_by: _ColumnExpressionArgument[Any] + ) -> WithinGroup[_T]: + """Produce a WITHIN GROUP (ORDER BY expr) clause against + this function. + """ + return WithinGroup(self, *order_by) + + def within_group_type( + self, within_group: WithinGroup[_T] + ) -> Optional[TypeEngine[_T]]: + return None + def self_group( self, against: Optional[OperatorType] = None ) -> Union[Self, Grouping[_T]]: diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index c47601b761..7782f215bc 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -845,6 +845,18 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_funcfilter_within_group(self): + self.assert_compile( + select( + func.rank() + .filter(table1.c.name > "foo") + .within_group(table1.c.name) + ), + "SELECT rank() FILTER (WHERE mytable.name > :name_1) " + "WITHIN GROUP (ORDER BY mytable.name) " + "AS anon_1 FROM mytable", + ) + + def test_within_group(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group(table1.c.name), @@ -858,7 +870,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_multi(self): + def test_within_group_multi(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group( @@ -874,7 +886,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_desc(self): + def test_within_group_desc(self): stmt = select( table1.c.myid, func.percentile_cont(0.5).within_group(table1.c.name.desc()), @@ -888,7 +900,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): {"percentile_cont_1": 0.5}, ) - def test_funcfilter_within_group_w_over(self): + def test_within_group_w_over(self): stmt = select( table1.c.myid, func.percentile_cont(0.5) @@ -904,6 +916,23 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): {"percentile_cont_1": 0.5}, ) + def test_within_group_filter(self): + stmt = select( + table1.c.myid, + func.percentile_cont(0.5) + .within_group(table1.c.name) + .filter(table1.c.myid > 42), + ) + self.assert_compile( + stmt, + "SELECT mytable.myid, percentile_cont(:percentile_cont_1) " + "WITHIN GROUP (ORDER BY mytable.name) " + "FILTER (WHERE mytable.myid > :myid_1) " + "AS anon_1 " + "FROM mytable", + {"percentile_cont_1": 0.5, "myid_1": 42}, + ) + def test_incorrect_none_type(self): from sqlalchemy.sql.expression import FunctionElement diff --git a/test/typing/plain_files/sql/functions_again.py b/test/typing/plain_files/sql/functions_again.py index 09e5e75f69..c3acf0ed27 100644 --- a/test/typing/plain_files/sql/functions_again.py +++ b/test/typing/plain_files/sql/functions_again.py @@ -18,7 +18,8 @@ class Foo(Base): c: Mapped[str] -func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())) func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()]) func.row_number().over(order_by="a", partition_by=("a", "b")) @@ -29,17 +30,23 @@ func.row_number().over(partition_by="a", order_by=("a", "b")) reveal_type(func.row_number().filter()) # EXPECTED_TYPE: FunctionFilter[Any] reveal_type(func.row_number().filter(Foo.a > 0)) - +# EXPECTED_TYPE: FunctionFilter[Any] +reveal_type(func.row_number().within_group(Foo.a).filter(Foo.b < 0)) +# EXPECTED_TYPE: WithinGroup[Any] +reveal_type(func.row_number().within_group(Foo.a)) +# EXPECTED_TYPE: WithinGroup[Any] +reveal_type(func.row_number().filter(Foo.a > 0).within_group(Foo.a)) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().filter(Foo.a > 0).over()) +# EXPECTED_TYPE: Over[Any] +reveal_type(func.row_number().within_group(Foo.a).over()) # test #10801 # EXPECTED_TYPE: max[int] reveal_type(func.max(Foo.b)) -stmt1 = select( - Foo.a, - func.min(Foo.b), -).group_by(Foo.a) +stmt1 = select(Foo.a, func.min(Foo.b)).group_by(Foo.a) # EXPECTED_TYPE: Select[int, int] reveal_type(stmt1) @@ -48,10 +55,7 @@ reveal_type(stmt1) reveal_type(func.coalesce(Foo.c, "a", "b")) -stmt2 = select( - Foo.a, - func.coalesce(Foo.c, "a", "b"), -).group_by(Foo.a) +stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a) # EXPECTED_TYPE: Select[int, str] reveal_type(stmt2)