]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix typing generics in PostgreSQL range types.
authorJim Bosch <jbosch@astro.princeton.edu>
Tue, 14 Nov 2023 21:19:31 +0000 (16:19 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 7 Feb 2024 18:11:51 +0000 (19:11 +0100)
Correctly type PostgreSQL RANGE and MULTIRANGE types as ``Range[T]``
and ``Sequence[Range[T]]``.
Introduced utility sequence ``MultiRange`` to allow better
interoperability of MULTIRANGE types.

Fixes #9736
Closes: #10625
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10625
Pull-request-sha: 2c17bc5f922a2bdb805a29e458184076ccc08055

Change-Id: I4f91d0233b29fd8101e67bdd4cd0aa2524ab788a
(cherry picked from commit 4006cb38e13ac471655f5f27102678ed8933ee60)

12 files changed:
doc/build/changelog/unreleased_20/9736.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/ranges.py
setup.cfg
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_types.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

diff --git a/doc/build/changelog/unreleased_20/9736.rst b/doc/build/changelog/unreleased_20/9736.rst
new file mode 100644 (file)
index 0000000..deb1703
--- /dev/null
@@ -0,0 +1,16 @@
+.. change::
+    :tags: postgresql, usecase
+    :tickets: 9736
+
+    Correctly type PostgreSQL RANGE and MULTIRANGE types as ``Range[T]``
+    and ``Sequence[Range[T]]``.
+    Introduced utility sequence :class:`_postgresql.MultiRange` to allow better
+    interoperability of MULTIRANGE types.
+
+.. change::
+    :tags: postgresql, usecase
+
+    Differentiate between INT4 and INT8 ranges and multi-ranges types when
+    inferring the database type from a :class:`_postgresql.Range` or
+    :class:`_postgresql.MultiRange` instance, preferring INT4 if the values
+    fit into it.
index 0575837185c205eca9be89376d9cbd88c95757e1..e822d069ce672f21ad77db295a6bd367092ec9bb 100644 (file)
@@ -238,6 +238,8 @@ dialect, **does not** support multirange datatypes.
 .. versionadded:: 2.0.17 Added multirange support for the pg8000 dialect.
    pg8000 1.29.8 or greater is required.
 
+.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added.
+
 The example below illustrates use of the :class:`_postgresql.TSMULTIRANGE`
 datatype::
 
@@ -260,6 +262,7 @@ datatype::
 
         id: Mapped[int] = mapped_column(primary_key=True)
         event_name: Mapped[str]
+        added: Mapped[datetime]
         in_session_periods: Mapped[List[Range[datetime]]] = mapped_column(TSMULTIRANGE)
 
 Illustrating insertion and selecting of a record::
@@ -294,6 +297,38 @@ Illustrating insertion and selecting of a record::
    a new list to the attribute, or use the :class:`.MutableList`
    type modifier.  See the section :ref:`mutable_toplevel` for background.
 
+.. _postgresql_multirange_list_use:
+
+Use of a MultiRange sequence to infer the multirange type
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+When using a multirange as a literal without specifying the type
+the utility :class:`_postgresql.MultiRange` sequence can be used::
+
+    from sqlalchemy import literal
+    from sqlalchemy.dialects.postgresql import MultiRange
+
+    with Session(engine) as session:
+        stmt = select(EventCalendar).where(
+            EventCalendar.added.op("<@")(
+                MultiRange(
+                    [
+                        Range(datetime(2023, 1, 1), datetime(2013, 3, 31)),
+                        Range(datetime(2023, 7, 1), datetime(2013, 9, 30)),
+                    ]
+                )
+            )
+        )
+        in_range = session.execute(stmt).all()
+
+    with engine.connect() as conn:
+        row = conn.scalar(select(literal(MultiRange([Range(2, 4)]))))
+        print(f"{row.lower} -> {row.upper}")
+
+Using a simple ``list`` instead of :class:`_postgresql.MultiRange` would require
+manually setting the type of the literal value to the appropriate multirange type.
+
+.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added.
 
 The available multirange datatypes are as follows:
 
@@ -416,6 +451,8 @@ construction arguments, are as follows:
 .. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange
     :members: comparator_factory
 
+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractSingleRange
+
 .. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange
 
 
@@ -529,6 +566,9 @@ construction arguments, are as follows:
 .. autoclass:: TSTZMULTIRANGE
 
 
+.. autoclass:: MultiRange
+
+
 PostgreSQL SQL Elements and Functions
 --------------------------------------
 
index 8dfa54d3aca4274d9ba6a73e37040fc0a6c1462a..17b14f4d05b64523d3c0036b6aa5eb6ff64c0f9e 100644 (file)
@@ -57,12 +57,14 @@ from .named_types import ENUM
 from .named_types import NamedType
 from .ranges import AbstractMultiRange
 from .ranges import AbstractRange
+from .ranges import AbstractSingleRange
 from .ranges import DATEMULTIRANGE
 from .ranges import DATERANGE
 from .ranges import INT4MULTIRANGE
 from .ranges import INT4RANGE
 from .ranges import INT8MULTIRANGE
 from .ranges import INT8RANGE
+from .ranges import MultiRange
 from .ranges import NUMMULTIRANGE
 from .ranges import NUMRANGE
 from .ranges import Range
index 2c460412c09e31faa69fcec851749886b3711a73..af097e283d32188ac1087ac33feb4d8db5bf40e3 100644 (file)
@@ -176,8 +176,6 @@ import decimal
 import json as _py_json
 import re
 import time
-from typing import cast
-from typing import TYPE_CHECKING
 
 from . import json
 from . import ranges
@@ -207,9 +205,6 @@ from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
-if TYPE_CHECKING:
-    from typing import Iterable
-
 
 class AsyncpgARRAY(PGARRAY):
     render_bind_cast = True
@@ -361,7 +356,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
     render_bind_cast = True
 
 
-class _AsyncpgRange(ranges.AbstractRangeImpl):
+class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
     def bind_processor(self, dialect):
         asyncpg_Range = dialect.dbapi.asyncpg.Range
 
@@ -415,10 +410,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
                     )
                 return value
 
-            return [
-                to_range(element)
-                for element in cast("Iterable[ranges.Range]", value)
-            ]
+            return [to_range(element) for element in value]
 
         return to_range
 
@@ -437,7 +429,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
                 return rvalue
 
             if value is not None:
-                value = [to_range(elem) for elem in value]
+                value = ranges.MultiRange(to_range(elem) for elem in value)
 
             return value
 
@@ -1050,7 +1042,7 @@ class PGDialect_asyncpg(PGDialect):
             OID: AsyncpgOID,
             REGCLASS: AsyncpgREGCLASS,
             sqltypes.CHAR: AsyncpgCHAR,
-            ranges.AbstractRange: _AsyncpgRange,
+            ranges.AbstractSingleRange: _AsyncpgRange,
             ranges.AbstractMultiRange: _AsyncpgMultiRange,
         },
     )
index fd7d9a37880e2e55c3332695d91952ba97069aa4..0151be0253daed70687a7e129dfc5f3261be9ca6 100644 (file)
@@ -253,7 +253,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
     pass
 
 
-class _Pg8000Range(ranges.AbstractRangeImpl):
+class _Pg8000Range(ranges.AbstractSingleRangeImpl):
     def bind_processor(self, dialect):
         pg8000_Range = dialect.dbapi.Range
 
@@ -304,15 +304,13 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
         def to_multirange(value):
             if value is None:
                 return None
-
-            mr = []
-            for v in value:
-                mr.append(
+            else:
+                return ranges.MultiRange(
                     ranges.Range(
                         v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
                     )
+                    for v in value
                 )
-            return mr
 
         return to_multirange
 
index df3d50e4867b38988ee75dac694bc41fd0ff5218..90177a43cebbe0effbc2a33bc0e7e9850d9c4470 100644 (file)
@@ -164,7 +164,7 @@ class _PGBoolean(sqltypes.Boolean):
     render_bind_cast = True
 
 
-class _PsycopgRange(ranges.AbstractRangeImpl):
+class _PsycopgRange(ranges.AbstractSingleRangeImpl):
     def bind_processor(self, dialect):
         psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
 
@@ -220,8 +220,10 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
 
     def result_processor(self, dialect, coltype):
         def to_range(value):
-            if value is not None:
-                value = [
+            if value is None:
+                return None
+            else:
+                return ranges.MultiRange(
                     ranges.Range(
                         elem._lower,
                         elem._upper,
@@ -229,9 +231,7 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
                         empty=not elem._bounds,
                     )
                     for elem in value
-                ]
-
-            return value
+                )
 
         return to_range
 
@@ -288,7 +288,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
             sqltypes.Integer: _PGInteger,
             sqltypes.SmallInteger: _PGSmallInteger,
             sqltypes.BigInteger: _PGBigInteger,
-            ranges.AbstractRange: _PsycopgRange,
+            ranges.AbstractSingleRange: _PsycopgRange,
             ranges.AbstractMultiRange: _PsycopgMultiRange,
         },
     )
index 0b89149ec9d6e48711130bc26cb3d308b44a4653..9bf2e4933618ce29742855bb5fbb498ea38d9729 100644 (file)
@@ -513,7 +513,7 @@ class _PGJSONB(JSONB):
         return None
 
 
-class _Psycopg2Range(ranges.AbstractRangeImpl):
+class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
     _psycopg2_range_cls = "none"
 
     def bind_processor(self, dialect):
index 980f1449359b57edba28c843833e526928cdce7e..b793ca49f1852afd1ad94afc3bdffdd031f5fd1d 100644 (file)
@@ -15,8 +15,10 @@ from decimal import Decimal
 from typing import Any
 from typing import cast
 from typing import Generic
+from typing import List
 from typing import Optional
 from typing import overload
+from typing import Sequence
 from typing import Tuple
 from typing import Type
 from typing import TYPE_CHECKING
@@ -152,8 +154,8 @@ class Range(Generic[_T]):
         return not self.empty and self.upper is None
 
     @property
-    def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
-        return AbstractRange()
+    def __sa_type_engine__(self) -> AbstractSingleRange[_T]:
+        return AbstractSingleRange()
 
     def _contains_value(self, value: _T) -> bool:
         """Return True if this range contains the given value."""
@@ -708,15 +710,34 @@ class Range(Generic[_T]):
         return f"{b0}{l},{r}{b1}"
 
 
-class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
-    """
-    Base for PostgreSQL RANGE types.
+class MultiRange(List[Range[_T]]):
+    """Represents a multirange sequence.
+
+    This list subclass is an utility to allow automatic type inference of
+    the proper multi-range SQL type depending on the single range values.
+    This is useful when operating on literal multi-ranges::
+
+        import sqlalchemy as sa
+        from sqlalchemy.dialects.postgresql import MultiRange, Range
+
+        value = literal(MultiRange([Range(2, 4)]))
+
+        select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)])))
+
+    .. versionadded:: 2.0.26
 
     .. seealso::
 
-        `PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
+        - :ref:`postgresql_multirange_list_use`.
+    """
 
-    """  # noqa: E501
+    @property
+    def __sa_type_engine__(self) -> AbstractMultiRange[_T]:
+        return AbstractMultiRange()
+
+
+class AbstractRange(sqltypes.TypeEngine[_T]):
+    """Base class for single and multi Range SQL types."""
 
     render_bind_cast = True
 
@@ -742,7 +763,10 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
         and also render as ``INT4RANGE`` in SQL and DDL.
 
         """
-        if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
+        if (
+            issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl))
+            and cls is not self.__class__
+        ):
             # two ways to do this are:  1. create a new type on the fly
             # or 2. have AbstractRangeImpl(visit_name) constructor and a
             # visit_abstract_range_impl() method in the PG compiler.
@@ -761,21 +785,6 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
         else:
             return super().adapt(cls)
 
-    def _resolve_for_literal(self, value: Any) -> Any:
-        spec = value.lower if value.lower is not None else value.upper
-
-        if isinstance(spec, int):
-            return INT8RANGE()
-        elif isinstance(spec, (Decimal, float)):
-            return NUMRANGE()
-        elif isinstance(spec, datetime):
-            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
-        elif isinstance(spec, date):
-            return DATERANGE()
-        else:
-            # empty Range, SQL datatype can't be determined here
-            return sqltypes.NULLTYPE
-
     class comparator_factory(TypeEngine.Comparator[Range[Any]]):
         """Define comparison operations for range types."""
 
@@ -857,91 +866,164 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
             return self.expr.operate(operators.mul, other)
 
 
-class AbstractRangeImpl(AbstractRange[Range[_T]]):
-    """Marker for AbstractRange that will apply a subclass-specific
+class AbstractSingleRange(AbstractRange[Range[_T]]):
+    """Base for PostgreSQL RANGE types.
+
+    These are types that return a single :class:`_postgresql.Range` object.
+
+    .. seealso::
+
+        `PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
+
+    """  # noqa: E501
+
+    __abstract__ = True
+
+    def _resolve_for_literal(self, value: Range[Any]) -> Any:
+        spec = value.lower if value.lower is not None else value.upper
+
+        if isinstance(spec, int):
+            # pg is unreasonably picky here: the query
+            # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises
+            # "operator does not exist: integer <@ int8range" as of pg 16
+            if _is_int32(value):
+                return INT4RANGE()
+            else:
+                return INT8RANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMRANGE()
+        elif isinstance(spec, datetime):
+            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
+        elif isinstance(spec, date):
+            return DATERANGE()
+        else:
+            # empty Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+
+
+class AbstractSingleRangeImpl(AbstractSingleRange[_T]):
+    """Marker for AbstractSingleRange that will apply a subclass-specific
     adaptation"""
 
 
-class AbstractMultiRange(AbstractRange[Range[_T]]):
-    """base for PostgreSQL MULTIRANGE types"""
+class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]):
+    """Base for PostgreSQL MULTIRANGE types.
+
+    these are types that return a sequence of :class:`_postgresql.Range`
+    objects.
+
+    """
 
     __abstract__ = True
 
+    def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any:
+        if not value:
+            # empty MultiRange, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+        first = value[0]
+        spec = first.lower if first.lower is not None else first.upper
 
-class AbstractMultiRangeImpl(
-    AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
-):
-    """Marker for AbstractRange that will apply a subclass-specific
+        if isinstance(spec, int):
+            # pg is unreasonably picky here: the query
+            # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises
+            # "operator does not exist: integer <@ int8multirange" as of pg 16
+            if all(_is_int32(r) for r in value):
+                return INT4MULTIRANGE()
+            else:
+                return INT8MULTIRANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMMULTIRANGE()
+        elif isinstance(spec, datetime):
+            return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE()
+        elif isinstance(spec, date):
+            return DATEMULTIRANGE()
+        else:
+            # empty Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+
+
+class AbstractMultiRangeImpl(AbstractMultiRange[_T]):
+    """Marker for AbstractMultiRange that will apply a subclass-specific
     adaptation"""
 
 
-class INT4RANGE(AbstractRange[Range[int]]):
+class INT4RANGE(AbstractSingleRange[int]):
     """Represent the PostgreSQL INT4RANGE type."""
 
     __visit_name__ = "INT4RANGE"
 
 
-class INT8RANGE(AbstractRange[Range[int]]):
+class INT8RANGE(AbstractSingleRange[int]):
     """Represent the PostgreSQL INT8RANGE type."""
 
     __visit_name__ = "INT8RANGE"
 
 
-class NUMRANGE(AbstractRange[Range[Decimal]]):
+class NUMRANGE(AbstractSingleRange[Decimal]):
     """Represent the PostgreSQL NUMRANGE type."""
 
     __visit_name__ = "NUMRANGE"
 
 
-class DATERANGE(AbstractRange[Range[date]]):
+class DATERANGE(AbstractSingleRange[date]):
     """Represent the PostgreSQL DATERANGE type."""
 
     __visit_name__ = "DATERANGE"
 
 
-class TSRANGE(AbstractRange[Range[datetime]]):
+class TSRANGE(AbstractSingleRange[datetime]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSRANGE"
 
 
-class TSTZRANGE(AbstractRange[Range[datetime]]):
+class TSTZRANGE(AbstractSingleRange[datetime]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZRANGE"
 
 
-class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
+class INT4MULTIRANGE(AbstractMultiRange[int]):
     """Represent the PostgreSQL INT4MULTIRANGE type."""
 
     __visit_name__ = "INT4MULTIRANGE"
 
 
-class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
+class INT8MULTIRANGE(AbstractMultiRange[int]):
     """Represent the PostgreSQL INT8MULTIRANGE type."""
 
     __visit_name__ = "INT8MULTIRANGE"
 
 
-class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
+class NUMMULTIRANGE(AbstractMultiRange[Decimal]):
     """Represent the PostgreSQL NUMMULTIRANGE type."""
 
     __visit_name__ = "NUMMULTIRANGE"
 
 
-class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
+class DATEMULTIRANGE(AbstractMultiRange[date]):
     """Represent the PostgreSQL DATEMULTIRANGE type."""
 
     __visit_name__ = "DATEMULTIRANGE"
 
 
-class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
+class TSMULTIRANGE(AbstractMultiRange[datetime]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSMULTIRANGE"
 
 
-class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
+class TSTZMULTIRANGE(AbstractMultiRange[datetime]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZMULTIRANGE"
+
+
+_max_int_32 = 2**31 - 1
+_min_int_32 = -(2**31)
+
+
+def _is_int32(r: Range[int]) -> bool:
+    return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and (
+        r.upper is None or _min_int_32 <= r.upper <= _max_int_32
+    )
index d51e4d854cc829a62546059c062ef2f67779cbe2..2a8a68132adbd2cf4fd7316675bc071927f33cfd 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -181,7 +181,7 @@ mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test
 mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes
 mssql_async = mssql+aioodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes
 pymssql = mssql+pymssql://scott:tiger^5HHH@mssql2017:1433/test
-docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server
+docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes
 oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
 cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
 oracledb = oracle+oracledb://scott:tiger@oracle18c/xe
index 005e60eaa14fd33eace35613d8c6a3c88b123986..10144d63a6969714d181202b310add218a910c8b 100644 (file)
@@ -52,6 +52,7 @@ from sqlalchemy.dialects.postgresql import TSQUERY
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
+from sqlalchemy.dialects.postgresql.ranges import MultiRange
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import Session
@@ -2588,7 +2589,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(expr, expected)
 
-    def test_custom_object_hook(self):
+    def test_range_custom_object_hook(self):
         # See issue #8884
         from datetime import date
 
@@ -2608,6 +2609,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE usages.date <@ %(date_1)s::DATERANGE",
         )
 
+    def test_multirange_custom_object_hook(self):
+        from datetime import date
+
+        usages = table(
+            "usages",
+            column("id", Integer),
+            column("date", Date),
+            column("amount", Integer),
+        )
+        period = MultiRange(
+            [
+                Range(date(2022, 1, 1), (2023, 1, 1)),
+                Range(date(2024, 1, 1), (2025, 1, 1)),
+            ]
+        )
+        stmt = select(func.sum(usages.c.amount)).where(
+            usages.c.date.op("<@")(period)
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT sum(usages.amount) AS sum_1 FROM usages "
+            "WHERE usages.date <@ %(date_1)s::DATEMULTIRANGE",
+        )
+
     def test_bitwise_xor(self):
         c1 = column("c1", Integer)
         c2 = column("c2", Integer)
index 2088436eebf1cace8be929cfa23b251c6eb6a29b..a5093c0bc90bca004864b580a591dced303b21f0 100644 (file)
@@ -73,6 +73,7 @@ from sqlalchemy.dialects.postgresql import TSMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSTZRANGE
+from sqlalchemy.dialects.postgresql.ranges import MultiRange
 from sqlalchemy.exc import CompileError
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.orm import declarative_base
@@ -92,6 +93,7 @@ from sqlalchemy.testing.assertions import AssertsExecutionResults
 from sqlalchemy.testing.assertions import ComparesTables
 from sqlalchemy.testing.assertions import eq_
 from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.assertions import ne_
 from sqlalchemy.testing.assertsql import RegexSQL
 from sqlalchemy.testing.schema import pep435_enum
 from sqlalchemy.testing.suite import test_types as suite
@@ -3887,6 +3889,53 @@ class HStoreRoundTripTest(fixtures.TablesTest):
             eq_(s.query(Data.data, Data).all(), [(d.data, d)])
 
 
+class RangeMiscTests(fixtures.TestBase):
+    @testing.combinations(
+        (Range(2, 7), INT4RANGE),
+        (Range(-10, 7), INT4RANGE),
+        (Range(None, -7), INT4RANGE),
+        (Range(33, None), INT4RANGE),
+        (Range(-2147483648, 2147483647), INT4RANGE),
+        (Range(-2147483648 - 1, 2147483647), INT8RANGE),
+        (Range(-2147483648, 2147483647 + 1), INT8RANGE),
+        (Range(-2147483648 - 1, None), INT8RANGE),
+        (Range(None, 2147483647 + 1), INT8RANGE),
+    )
+    def test_resolve_for_literal(self, obj, type_):
+        """This tests that the int4 / int8 version is selected correctly by
+        _resolve_for_literal."""
+        lit = literal(obj)
+        eq_(type(lit.type), type_)
+
+    @testing.combinations(
+        (Range(2, 7), INT4MULTIRANGE),
+        (Range(-10, 7), INT4MULTIRANGE),
+        (Range(None, -7), INT4MULTIRANGE),
+        (Range(33, None), INT4MULTIRANGE),
+        (Range(-2147483648, 2147483647), INT4MULTIRANGE),
+        (Range(-2147483648 - 1, 2147483647), INT8MULTIRANGE),
+        (Range(-2147483648, 2147483647 + 1), INT8MULTIRANGE),
+        (Range(-2147483648 - 1, None), INT8MULTIRANGE),
+        (Range(None, 2147483647 + 1), INT8MULTIRANGE),
+    )
+    def test_resolve_for_literal_multi(self, obj, type_):
+        """This tests that the int4 / int8 version is selected correctly by
+        _resolve_for_literal."""
+        list_ = MultiRange([Range(-1, 1), obj, Range(7, 100)])
+        lit = literal(list_)
+        eq_(type(lit.type), type_)
+
+    def test_multirange_sequence(self):
+        plain = [Range(-1, 1), Range(42, 43), Range(7, 100)]
+        mr = MultiRange(plain)
+        is_true(issubclass(MultiRange, list))
+        is_true(isinstance(mr, list))
+        eq_(mr, plain)
+        eq_(str(mr), str(plain))
+        eq_(repr(mr), repr(plain))
+        ne_(mr, plain[1:])
+
+
 class _RangeTests:
     _col_type = None
     "The concrete range class these tests are for."
@@ -4641,11 +4690,21 @@ class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
         Brought up in #8540.
 
         """
+        # see also CompileTest::test_range_custom_object_hook
         data_obj = self._data_obj()
         stmt = select(literal(data_obj, type_=self._col_type))
         round_trip = connection.scalar(stmt)
         eq_(round_trip, data_obj)
 
+    def test_auto_cast_back_to_type_without_type(self, connection):
+        """use _resolve_for_literal to cast"""
+        # see also CompileTest::test_range_custom_object_hook
+        data_obj = self._data_obj()
+        lit = literal(data_obj)
+        round_trip = connection.scalar(select(lit))
+        eq_(round_trip, data_obj)
+        eq_(type(lit.type), self._col_type)
+
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)
 
@@ -5140,10 +5199,17 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
         )
 
 
-class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
+class _MultiRangeTypeRoundTrip(fixtures.TablesTest, _RangeTests):
     __requires__ = ("multirange_types",)
     __backend__ = True
 
+    @testing.fixture(params=(True, False), ids=["multirange", "plain_list"])
+    def data_obj(self, request):
+        if request.param:
+            return MultiRange(self._data_obj())
+        else:
+            return list(self._data_obj())
+
     @classmethod
     def define_tables(cls, metadata):
         # no reason ranges shouldn't be primary keys,
@@ -5155,7 +5221,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         )
         cls.col = table.c.range
 
-    def test_auto_cast_back_to_type(self, connection):
+    def test_auto_cast_back_to_type(self, connection, data_obj):
         """test that a straight pass of the range type without any context
         will send appropriate casting info so that the driver can round
         trip it.
@@ -5170,11 +5236,29 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         Brought up in #8540.
 
         """
-        data_obj = self._data_obj()
+        # see also CompileTest::test_multirange_custom_object_hook
         stmt = select(literal(data_obj, type_=self._col_type))
         round_trip = connection.scalar(stmt)
         eq_(round_trip, data_obj)
 
+    def test_auto_cast_back_to_type_without_type(self, connection):
+        """use _resolve_for_literal to cast"""
+        # see also CompileTest::test_multirange_custom_object_hook
+        data_obj = MultiRange(self._data_obj())
+        lit = literal(data_obj)
+        round_trip = connection.scalar(select(lit))
+        eq_(round_trip, data_obj)
+        eq_(type(lit.type), self._col_type)
+
+    @testing.fails("no automatic adaptation of plain list")
+    def test_auto_cast_back_to_type_without_type_plain_list(self, connection):
+        """use _resolve_for_literal to cast"""
+        # see also CompileTest::test_multirange_custom_object_hook
+        data_obj = list(self._data_obj())
+        lit = literal(data_obj)
+        r = connection.scalar(select(lit))
+        eq_(type(r), list)
+
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)
 
@@ -5188,12 +5272,12 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
     def _assert_data(self, conn):
         data = conn.execute(select(self.tables.data_table.c.range)).fetchall()
         eq_(data, [(self._data_obj(),)])
+        eq_(type(data[0][0]), MultiRange)
 
-    def test_textual_round_trip_w_dialect_type(self, connection):
+    def test_textual_round_trip_w_dialect_type(self, connection, data_obj):
         """test #8690"""
         data_table = self.tables.data_table
 
-        data_obj = self._data_obj()
         connection.execute(
             self.tables.data_table.insert(), {"range": data_obj}
         )
@@ -5206,9 +5290,9 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
 
         eq_(data_obj, v2)
 
-    def test_insert_obj(self, connection):
+    def test_insert_obj(self, connection, data_obj):
         connection.execute(
-            self.tables.data_table.insert(), {"range": self._data_obj()}
+            self.tables.data_table.insert(), {"range": data_obj}
         )
         self._assert_data(connection)
 
@@ -5229,6 +5313,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         range_ = self.tables.data_table.c.range
         data = connection.execute(select(range_ + range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
+        eq_(type(data[0][0]), MultiRange)
 
     @testing.requires.psycopg_or_pg8000_compatibility
     def test_intersection_result_text(self, connection):
@@ -5240,6 +5325,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         range_ = self.tables.data_table.c.range
         data = connection.execute(select(range_ * range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
+        eq_(type(data[0][0]), MultiRange)
 
     @testing.requires.psycopg_or_pg8000_compatibility
     def test_difference_result_text(self, connection):
@@ -5251,6 +5337,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         range_ = self.tables.data_table.c.range
         data = connection.execute(select(range_ - range_)).fetchall()
         eq_(data, [([],)])
+        eq_(type(data[0][0]), MultiRange)
 
 
 class _Int4MultiRangeTests:
@@ -5261,11 +5348,7 @@ class _Int4MultiRangeTests:
         return "{[1,2), [3, 5), [9, 12)}"
 
     def _data_obj(self):
-        return [
-            Range(1, 2),
-            Range(3, 5),
-            Range(9, 12),
-        ]
+        return [Range(1, 2), Range(3, 5), Range(9, 12)]
 
 
 class _Int8MultiRangeTests:
@@ -5465,6 +5548,17 @@ class DateTimeTZRMultiangeRoundTripTest(
     pass
 
 
+class MultiRangeSequenceTest(fixtures.TestBase):
+    def test_methods(self):
+        plain = [Range(1, 3), Range(5, 9)]
+        multi = MultiRange(plain)
+        is_true(isinstance(multi, list))
+        eq_(multi, plain)
+        ne_(multi, plain[:1])
+        eq_(str(multi), str(plain))
+        eq_(repr(multi), repr(plain))
+
+
 class JSONTest(AssertsCompiledSQL, fixtures.TestBase):
     __dialect__ = "postgresql"
 
index 4567daa38665ddafe36597d0cd283ac5ae067875..678d22b71f97562ec10e598f6377b4127bcad27d 100644 (file)
@@ -12,14 +12,17 @@ from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
+from sqlalchemy.dialects.postgresql import DATERANGE
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.dialects.postgresql import INT4RANGE
+from sqlalchemy.dialects.postgresql import INT8MULTIRANGE
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 
-
 # test #6402
 
 c1 = Column(UUID())
@@ -77,3 +80,19 @@ insert(Test).on_conflict_do_nothing(
 ).on_conflict_do_update(
     unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
 ).excluded.foo.desc()
+
+
+# EXPECTED_TYPE: Column[Range[int]]
+reveal_type(Column(INT4RANGE()))
+# EXPECTED_TYPE: Column[Range[datetime.date]]
+reveal_type(Column("foo", DATERANGE()))
+# EXPECTED_TYPE: Column[Sequence[Range[int]]]
+reveal_type(Column(INT8MULTIRANGE()))
+# EXPECTED_TYPE: Column[Sequence[Range[datetime.datetime]]]
+reveal_type(Column("foo", TSTZMULTIRANGE()))
+
+
+range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE()))
+
+# EXPECTED_TYPE: Select[Tuple[Range[int], Sequence[Range[int]]]]
+reveal_type(range_col_stmt)