]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
test cases, fixes
authorViolet Folino Gallo <48537601+galloviolet@users.noreply.github.com>
Tue, 24 Jun 2025 19:44:14 +0000 (12:44 -0700)
committerViolet Folino Gallo <48537601+galloviolet@users.noreply.github.com>
Tue, 24 Jun 2025 19:59:40 +0000 (12:59 -0700)
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/functions.py
lib/sqlalchemy/testing/suite/test_select.py
test/sql/test_functions.py

index 691d28afac9710efd6cd4af7c13609b7a23d31f4..1e1b26e86325ecc7124832f7446ce445d52d05f4 100644 (file)
@@ -4315,34 +4315,35 @@ class _FrameClause(ClauseElement):
             lower_value, upper_value = range_
         except (ValueError, TypeError) as ve:
             raise exc.ArgumentError("2-tuple expected for range/rows") from ve
+        
+        lower_type = type_api.INTEGERTYPE if isinstance(lower_value, int) else type_api.NUMERICTYPE
+        upper_type = type_api.INTEGERTYPE if isinstance(upper_value, int) else type_api.NUMERICTYPE
 
         if lower_value is None:
             self.lower_type = _FrameClauseType.RANGE_UNBOUNDED
             self.lower_bind = None
+        elif lower_value == 0:
+            self.lower_type = _FrameClauseType.RANGE_CURRENT
+            self.lower_bind = None
+        elif lower_value < 0:
+            self.lower_type = _FrameClauseType.RANGE_PRECEDING
+            self.lower_bind = literal(abs(lower_value), lower_type)
         else:
-            if lower_value == 0:
-                self.lower_type = _FrameClauseType.RANGE_CURRENT
-                self.lower_bind = None
-            elif lower_value < 0:
-                self.lower_type = _FrameClauseType.RANGE_PRECEDING
-                self.lower_bind = literal(abs(lower_value), type_api.NULLTYPE)
-            else:
-                self.lower_type = _FrameClauseType.RANGE_FOLLOWING
-                self.lower_bind = literal(lower_value, type_api.NULLTYPE)
+            self.lower_type = _FrameClauseType.RANGE_FOLLOWING
+            self.lower_bind = literal(lower_value, lower_type)
 
         if upper_value is None:
             self.upper_type = _FrameClauseType.RANGE_UNBOUNDED
             self.upper_bind = None
+        elif upper_value == 0:
+            self.upper_type = _FrameClauseType.RANGE_CURRENT
+            self.upper_bind = None
+        elif upper_value < 0:
+            self.upper_type = _FrameClauseType.RANGE_PRECEDING
+            self.upper_bind = literal(abs(upper_value), upper_type)
         else:
-            if upper_value == 0:
-                self.upper_type = _FrameClauseType.RANGE_CURRENT
-                self.upper_bind = None
-            elif upper_value < 0:
-                self.upper_type = _FrameClauseType.RANGE_PRECEDING
-                self.upper_bind = literal(abs(upper_value), type_api.NULLTYPE)
-            else:
-                self.upper_type = _FrameClauseType.RANGE_FOLLOWING
-                self.upper_bind = literal(upper_value, type_api.NULLTYPE)
+            self.upper_type = _FrameClauseType.RANGE_FOLLOWING
+            self.upper_bind = literal(upper_value, upper_type)
 
 
 class WithinGroup(ColumnElement[_T]):
@@ -4525,7 +4526,7 @@ class FunctionFilter(Generative, ColumnElement[_T]):
                 _ColumnExpressionArgument[Any],
             ]
         ] = None,
-        range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
+        range_: Optional[typing_Tuple[Optional[Any], Optional[Any]]] = None,
         rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
         groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None,
     ) -> Over[_T]:
index 050f94fd8087685aed5f7b0329bd3d2619ea6b88..10301d4a3795767c3cc0d08fc89f580bf823171a 100644 (file)
@@ -433,7 +433,7 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
         partition_by: Optional[_ByArgument] = None,
         order_by: Optional[_ByArgument] = None,
         rows: Optional[Tuple[Optional[int], Optional[int]]] = None,
-        range_: Optional[Tuple[Optional[int], Optional[int]]] = None,
+        range_: Optional[Tuple[Optional[Any], Optional[Any]]] = None,
         groups: Optional[Tuple[Optional[int], Optional[int]]] = None,
     ) -> Over[_T]:
         """Produce an OVER clause against this function.
index 6b21bb67fe2dfca9c4578f1b622503cc379b7c1c..c54bc20d3a177e9b4f4473abd5765d2290e9576d 100644 (file)
@@ -42,6 +42,7 @@ from ... import tuple_
 from ... import TupleType
 from ... import union
 from ... import values
+from ... import Float
 from ...exc import DatabaseError
 from ...exc import ProgrammingError
 
@@ -1913,13 +1914,14 @@ class WindowFunctionTest(fixtures.TablesTest):
             Column("id", Integer, primary_key=True),
             Column("col1", Integer),
             Column("col2", Integer),
+            Column("col3", Float),
         )
 
     @classmethod
     def insert_data(cls, connection):
         connection.execute(
             cls.tables.some_table.insert(),
-            [{"id": i, "col1": i, "col2": i * 5} for i in range(1, 50)],
+            [{"id": i, "col1": i, "col2": i * 5, "col3": i + 0.5} for i in range(1, 50)],
         )
 
     def test_window(self, connection):
@@ -1934,6 +1936,20 @@ class WindowFunctionTest(fixtures.TablesTest):
 
         eq_(rows, [(95,) for i in range(19)])
 
+    def test_window_range(self, connection):
+        some_table = self.tables.some_table
+        rows = connection.execute(
+            select(
+                func.max(some_table.c.col3).over(
+                    partition_by=[some_table.c.col3],
+                    order_by=[some_table.c.col3.asc()],
+                    range_=(-1.25, 1.25),
+                )
+            ).where(some_table.c.col1 < 20)
+        ).all()
+
+        eq_(rows, [(i + 1.5,) for i in range(19)])
+
     def test_window_rows_between_w_caching(self, connection):
         some_table = self.tables.some_table
 
index 28cdb03a9657136af7a004de4052617d0c816a63..4ce06636b442aab58e2ff3bb3df887a1434ea3f3 100644 (file)
@@ -61,6 +61,7 @@ table1 = table(
     column("myid", Integer),
     column("name", String),
     column("description", String),
+    column("myfloat", Float),
 )
 
 
@@ -816,6 +817,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             checkparams={"name_1": "foo", "param_1": 1, "param_2": 5},
         )
 
+        self.assert_compile(
+            select(
+                func.rank()
+                .filter(table1.c.name > "foo")
+                .over(range_=(-3.14, 2.71), partition_by=["myfloat"])
+            ),
+            "SELECT rank() FILTER (WHERE mytable.name > :name_1) "
+            "OVER (PARTITION BY mytable.myfloat RANGE BETWEEN :param_1 "
+            "PRECEDING AND :param_2 FOLLOWING) "
+            "AS anon_1 FROM mytable",
+            checkparams={"name_1": "foo", "param_1": 3.14, "param_2": 2.71},
+        )
+
     def test_funcfilter_windowing_range_positional(self):
         self.assert_compile(
             select(