]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
account for sql server in test_window_range
authorViolet Folino Gallo <48537601+galloviolet@users.noreply.github.com>
Wed, 9 Jul 2025 23:50:55 +0000 (16:50 -0700)
committerViolet Folino Gallo <48537601+galloviolet@users.noreply.github.com>
Wed, 9 Jul 2025 23:50:55 +0000 (16:50 -0700)
lib/sqlalchemy/testing/suite/test_select.py

index 0b55a57c13558f4b6bd7959a715c6d4ca87c5faa..97833149b5d9f5f3fa3526409bc620963a82d766 100644 (file)
@@ -45,6 +45,7 @@ from ... import union
 from ... import values
 from ...exc import DatabaseError
 from ...exc import ProgrammingError
+from ...sql.elements import _FrameClauseType
 
 
 class CollateTest(fixtures.TablesTest):
@@ -1945,17 +1946,29 @@ class WindowFunctionTest(fixtures.TablesTest):
 
     def test_window_range(self, connection):
         some_table = self.tables.some_table
-        rows = connection.execute(
-            select(
-                func.max(some_table.c.col3).over(
-                    partition_by=[some_table.c.col3],
-                    order_by=[some_table.c.col3.asc()],
-                    range_=(-1.25, 1.25),
+        # SQL Server only allows UNBOUNDED and CURRENT ROW in the RANGE clause
+        if config.db.dialect in ["mssql+aiodbc", "mssql+pymssql", "mssql+pyodbc"]:
+            rows = connection.execute(
+                select(
+                    func.max(some_table.c.col1).over(
+                        partition_by=[some_table.c.col2],
+                        order_by=[some_table.c.col2.asc()],
+                        range_=(_FrameClauseType.RANGE_UNBOUNDED, _FrameClauseType.RANGE_CURRENT)
+                    )
                 )
-            ).where(some_table.c.col1 < 20)
-        ).all()
+            )
+        else:
+            rows = connection.execute(
+                select(
+                    func.max(some_table.c.col3).over(
+                        partition_by=[some_table.c.col3],
+                        order_by=[some_table.c.col3.asc()],
+                        range_=(-1.25, 1.25),
+                    )
+                ).where(some_table.c.col1 < 20)
+            ).all()
 
-        eq_(rows, [(i + 1.5,) for i in range(19)])
+            eq_(rows, [(i + 1.5,) for i in range(19)])
 
     def test_window_rows_between_w_caching(self, connection):
         some_table = self.tables.some_table