]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add missing function element methods
authorFederico Caselli <cfederico87@gmail.com>
Wed, 29 May 2024 20:18:50 +0000 (22:18 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 30 May 2024 20:17:48 +0000 (22:17 +0200)
Added missing methods :meth:`_sql.FunctionFilter.within_group`
and :meth:`_sql.WithinGroup.filter`

Fixes: #11423
Change-Id: I4bafd9e3cab5883b28b2b997269df239739a2212

doc/build/changelog/unreleased_20/11423.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
test/sql/test_functions.py
test/typing/plain_files/sql/functions_again.py

diff --git a/doc/build/changelog/unreleased_20/11423.rst b/doc/build/changelog/unreleased_20/11423.rst
new file mode 100644 (file)
index 0000000..ed6f988
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 11423
+
+    Added missing methods :meth:`_sql.FunctionFilter.within_group`
+    and :meth:`_sql.WithinGroup.filter`
index 080011eb7d0b8ca1e16f20ba23cc6449e0e896e8..cb43d11a1b208e3860b7f51a49af4e912f285a08 100644 (file)
@@ -4313,7 +4313,7 @@ class WithinGroup(ColumnElement[_T]):
 
     def __init__(
         self,
-        element: FunctionElement[_T],
+        element: Union[FunctionElement[_T], FunctionFilter[_T]],
         *order_by: _ColumnExpressionArgument[Any],
     ):
         self.element = element
@@ -4327,7 +4327,14 @@ class WithinGroup(ColumnElement[_T]):
             tuple(self.order_by) if self.order_by is not None else ()
         )
 
-    def over(self, partition_by=None, order_by=None, range_=None, rows=None):
+    def over(
+        self,
+        *,
+        partition_by: Optional[_ByArgument] = None,
+        order_by: Optional[_ByArgument] = None,
+        rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+        range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+    ) -> Over[_T]:
         """Produce an OVER clause against this :class:`.WithinGroup`
         construct.
 
@@ -4343,6 +4350,24 @@ class WithinGroup(ColumnElement[_T]):
             rows=rows,
         )
 
+    @overload
+    def filter(self) -> Self: ...
+
+    @overload
+    def filter(
+        self,
+        __criterion0: _ColumnExpressionArgument[bool],
+        *criterion: _ColumnExpressionArgument[bool],
+    ) -> FunctionFilter[_T]: ...
+
+    def filter(
+        self, *criterion: _ColumnExpressionArgument[bool]
+    ) -> Union[Self, FunctionFilter[_T]]:
+        """Produce a FILTER clause against this function."""
+        if not criterion:
+            return self
+        return FunctionFilter(self, *criterion)
+
     if not TYPE_CHECKING:
 
         @util.memoized_property
@@ -4395,7 +4420,7 @@ class FunctionFilter(ColumnElement[_T]):
 
     def __init__(
         self,
-        func: FunctionElement[_T],
+        func: Union[FunctionElement[_T], WithinGroup[_T]],
         *criterion: _ColumnExpressionArgument[bool],
     ):
         self.func = func
@@ -4465,6 +4490,19 @@ class FunctionFilter(ColumnElement[_T]):
             rows=rows,
         )
 
+    def within_group(
+        self, *order_by: _ColumnExpressionArgument[Any]
+    ) -> WithinGroup[_T]:
+        """Produce a WITHIN GROUP (ORDER BY expr) clause against
+        this function.
+        """
+        return WithinGroup(self, *order_by)
+
+    def within_group_type(
+        self, within_group: WithinGroup[_T]
+    ) -> Optional[TypeEngine[_T]]:
+        return None
+
     def self_group(
         self, against: Optional[OperatorType] = None
     ) -> Union[Self, Grouping[_T]]:
index c47601b7616622c9c66a78eed982fe93063ef7eb..7782f215bcd16df861cdf1c6938fddffbf1eaa74 100644 (file)
@@ -845,6 +845,18 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
     def test_funcfilter_within_group(self):
+        self.assert_compile(
+            select(
+                func.rank()
+                .filter(table1.c.name > "foo")
+                .within_group(table1.c.name)
+            ),
+            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
+            "WITHIN GROUP (ORDER BY mytable.name) "
+            "AS anon_1 FROM mytable",
+        )
+
+    def test_within_group(self):
         stmt = select(
             table1.c.myid,
             func.percentile_cont(0.5).within_group(table1.c.name),
@@ -858,7 +870,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             {"percentile_cont_1": 0.5},
         )
 
-    def test_funcfilter_within_group_multi(self):
+    def test_within_group_multi(self):
         stmt = select(
             table1.c.myid,
             func.percentile_cont(0.5).within_group(
@@ -874,7 +886,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             {"percentile_cont_1": 0.5},
         )
 
-    def test_funcfilter_within_group_desc(self):
+    def test_within_group_desc(self):
         stmt = select(
             table1.c.myid,
             func.percentile_cont(0.5).within_group(table1.c.name.desc()),
@@ -888,7 +900,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             {"percentile_cont_1": 0.5},
         )
 
-    def test_funcfilter_within_group_w_over(self):
+    def test_within_group_w_over(self):
         stmt = select(
             table1.c.myid,
             func.percentile_cont(0.5)
@@ -904,6 +916,23 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             {"percentile_cont_1": 0.5},
         )
 
+    def test_within_group_filter(self):
+        stmt = select(
+            table1.c.myid,
+            func.percentile_cont(0.5)
+            .within_group(table1.c.name)
+            .filter(table1.c.myid > 42),
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid, percentile_cont(:percentile_cont_1) "
+            "WITHIN GROUP (ORDER BY mytable.name) "
+            "FILTER (WHERE mytable.myid > :myid_1) "
+            "AS anon_1 "
+            "FROM mytable",
+            {"percentile_cont_1": 0.5, "myid_1": 42},
+        )
+
     def test_incorrect_none_type(self):
         from sqlalchemy.sql.expression import FunctionElement
 
index 09e5e75f69e89a54edbb87b938e94885a0920a4a..c3acf0ed270872b885c858d7ee1bd67e3056add2 100644 (file)
@@ -18,7 +18,8 @@ class Foo(Base):
     c: Mapped[str]
 
 
-func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc())
+# EXPECTED_TYPE: Over[Any]
+reveal_type(func.row_number().over(order_by=Foo.a, partition_by=Foo.b.desc()))
 func.row_number().over(order_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(partition_by=[Foo.a.desc(), Foo.b.desc()])
 func.row_number().over(order_by="a", partition_by=("a", "b"))
@@ -29,17 +30,23 @@ func.row_number().over(partition_by="a", order_by=("a", "b"))
 reveal_type(func.row_number().filter())
 # EXPECTED_TYPE: FunctionFilter[Any]
 reveal_type(func.row_number().filter(Foo.a > 0))
-
+# EXPECTED_TYPE: FunctionFilter[Any]
+reveal_type(func.row_number().within_group(Foo.a).filter(Foo.b < 0))
+# EXPECTED_TYPE: WithinGroup[Any]
+reveal_type(func.row_number().within_group(Foo.a))
+# EXPECTED_TYPE: WithinGroup[Any]
+reveal_type(func.row_number().filter(Foo.a > 0).within_group(Foo.a))
+# EXPECTED_TYPE: Over[Any]
+reveal_type(func.row_number().filter(Foo.a > 0).over())
+# EXPECTED_TYPE: Over[Any]
+reveal_type(func.row_number().within_group(Foo.a).over())
 
 # test #10801
 # EXPECTED_TYPE: max[int]
 reveal_type(func.max(Foo.b))
 
 
-stmt1 = select(
-    Foo.a,
-    func.min(Foo.b),
-).group_by(Foo.a)
+stmt1 = select(Foo.a, func.min(Foo.b)).group_by(Foo.a)
 # EXPECTED_TYPE: Select[int, int]
 reveal_type(stmt1)
 
@@ -48,10 +55,7 @@ reveal_type(stmt1)
 reveal_type(func.coalesce(Foo.c, "a", "b"))
 
 
-stmt2 = select(
-    Foo.a,
-    func.coalesce(Foo.c, "a", "b"),
-).group_by(Foo.a)
+stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
 # EXPECTED_TYPE: Select[int, str]
 reveal_type(stmt2)