From 14870221fbad2acf1e9f35132bc3e23872357a69 Mon Sep 17 00:00:00 2001 From: Jim Bosch Date: Tue, 14 Nov 2023 16:19:31 -0500 Subject: [PATCH] Fix typing generics in PostgreSQL range types. Correctly type PostgreSQL RANGE and MULTIRANGE types as ``Range[T]`` and ``Sequence[Range[T]]``. Introduced utility sequence ``MultiRange`` to allow better interoperability of MULTIRANGE types. Fixes #9736 Closes: #10625 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10625 Pull-request-sha: 2c17bc5f922a2bdb805a29e458184076ccc08055 Change-Id: I4f91d0233b29fd8101e67bdd4cd0aa2524ab788a (cherry picked from commit 4006cb38e13ac471655f5f27102678ed8933ee60) --- doc/build/changelog/unreleased_20/9736.rst | 16 ++ doc/build/dialects/postgresql.rst | 40 +++++ .../dialects/postgresql/__init__.py | 2 + lib/sqlalchemy/dialects/postgresql/asyncpg.py | 16 +- lib/sqlalchemy/dialects/postgresql/pg8000.py | 10 +- lib/sqlalchemy/dialects/postgresql/psycopg.py | 14 +- .../dialects/postgresql/psycopg2.py | 2 +- lib/sqlalchemy/dialects/postgresql/ranges.py | 168 +++++++++++++----- setup.cfg | 2 +- test/dialect/postgresql/test_compiler.py | 27 ++- test/dialect/postgresql/test_types.py | 118 ++++++++++-- .../dialects/postgresql/pg_stuff.py | 21 ++- 12 files changed, 352 insertions(+), 84 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/9736.rst diff --git a/doc/build/changelog/unreleased_20/9736.rst b/doc/build/changelog/unreleased_20/9736.rst new file mode 100644 index 0000000000..deb1703d87 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9736.rst @@ -0,0 +1,16 @@ +.. change:: + :tags: postgresql, usecase + :tickets: 9736 + + Correctly type PostgreSQL RANGE and MULTIRANGE types as ``Range[T]`` + and ``Sequence[Range[T]]``. + Introduced utility sequence :class:`_postgresql.MultiRange` to allow better + interoperability of MULTIRANGE types. + +.. change:: + :tags: postgresql, usecase + + Differentiate between INT4 and INT8 ranges and multi-ranges types when + inferring the database type from a :class:`_postgresql.Range` or + :class:`_postgresql.MultiRange` instance, preferring INT4 if the values + fit into it. diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 0575837185..e822d069ce 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -238,6 +238,8 @@ dialect, **does not** support multirange datatypes. .. versionadded:: 2.0.17 Added multirange support for the pg8000 dialect. pg8000 1.29.8 or greater is required. +.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added. + The example below illustrates use of the :class:`_postgresql.TSMULTIRANGE` datatype:: @@ -260,6 +262,7 @@ datatype:: id: Mapped[int] = mapped_column(primary_key=True) event_name: Mapped[str] + added: Mapped[datetime] in_session_periods: Mapped[List[Range[datetime]]] = mapped_column(TSMULTIRANGE) Illustrating insertion and selecting of a record:: @@ -294,6 +297,38 @@ Illustrating insertion and selecting of a record:: a new list to the attribute, or use the :class:`.MutableList` type modifier. See the section :ref:`mutable_toplevel` for background. +.. _postgresql_multirange_list_use: + +Use of a MultiRange sequence to infer the multirange type +""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + +When using a multirange as a literal without specifying the type +the utility :class:`_postgresql.MultiRange` sequence can be used:: + + from sqlalchemy import literal + from sqlalchemy.dialects.postgresql import MultiRange + + with Session(engine) as session: + stmt = select(EventCalendar).where( + EventCalendar.added.op("<@")( + MultiRange( + [ + Range(datetime(2023, 1, 1), datetime(2013, 3, 31)), + Range(datetime(2023, 7, 1), datetime(2013, 9, 30)), + ] + ) + ) + ) + in_range = session.execute(stmt).all() + + with engine.connect() as conn: + row = conn.scalar(select(literal(MultiRange([Range(2, 4)])))) + print(f"{row.lower} -> {row.upper}") + +Using a simple ``list`` instead of :class:`_postgresql.MultiRange` would require +manually setting the type of the literal value to the appropriate multirange type. + +.. versionadded:: 2.0.26 :class:`_postgresql.MultiRange` sequence added. The available multirange datatypes are as follows: @@ -416,6 +451,8 @@ construction arguments, are as follows: .. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange :members: comparator_factory +.. autoclass:: sqlalchemy.dialects.postgresql.AbstractSingleRange + .. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange @@ -529,6 +566,9 @@ construction arguments, are as follows: .. autoclass:: TSTZMULTIRANGE +.. autoclass:: MultiRange + + PostgreSQL SQL Elements and Functions -------------------------------------- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 8dfa54d3ac..17b14f4d05 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -57,12 +57,14 @@ from .named_types import ENUM from .named_types import NamedType from .ranges import AbstractMultiRange from .ranges import AbstractRange +from .ranges import AbstractSingleRange from .ranges import DATEMULTIRANGE from .ranges import DATERANGE from .ranges import INT4MULTIRANGE from .ranges import INT4RANGE from .ranges import INT8MULTIRANGE from .ranges import INT8RANGE +from .ranges import MultiRange from .ranges import NUMMULTIRANGE from .ranges import NUMRANGE from .ranges import Range diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 2c460412c0..af097e283d 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -176,8 +176,6 @@ import decimal import json as _py_json import re import time -from typing import cast -from typing import TYPE_CHECKING from . import json from . import ranges @@ -207,9 +205,6 @@ from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only -if TYPE_CHECKING: - from typing import Iterable - class AsyncpgARRAY(PGARRAY): render_bind_cast = True @@ -361,7 +356,7 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True -class _AsyncpgRange(ranges.AbstractRangeImpl): +class _AsyncpgRange(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): asyncpg_Range = dialect.dbapi.asyncpg.Range @@ -415,10 +410,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): ) return value - return [ - to_range(element) - for element in cast("Iterable[ranges.Range]", value) - ] + return [to_range(element) for element in value] return to_range @@ -437,7 +429,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): return rvalue if value is not None: - value = [to_range(elem) for elem in value] + value = ranges.MultiRange(to_range(elem) for elem in value) return value @@ -1050,7 +1042,7 @@ class PGDialect_asyncpg(PGDialect): OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, sqltypes.CHAR: AsyncpgCHAR, - ranges.AbstractRange: _AsyncpgRange, + ranges.AbstractSingleRange: _AsyncpgRange, ranges.AbstractMultiRange: _AsyncpgMultiRange, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index fd7d9a3788..0151be0253 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -253,7 +253,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): pass -class _Pg8000Range(ranges.AbstractRangeImpl): +class _Pg8000Range(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): pg8000_Range = dialect.dbapi.Range @@ -304,15 +304,13 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl): def to_multirange(value): if value is None: return None - - mr = [] - for v in value: - mr.append( + else: + return ranges.MultiRange( ranges.Range( v.lower, v.upper, bounds=v.bounds, empty=v.is_empty ) + for v in value ) - return mr return to_multirange diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index df3d50e486..90177a43ce 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -164,7 +164,7 @@ class _PGBoolean(sqltypes.Boolean): render_bind_cast = True -class _PsycopgRange(ranges.AbstractRangeImpl): +class _PsycopgRange(ranges.AbstractSingleRangeImpl): def bind_processor(self, dialect): psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range @@ -220,8 +220,10 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): def result_processor(self, dialect, coltype): def to_range(value): - if value is not None: - value = [ + if value is None: + return None + else: + return ranges.MultiRange( ranges.Range( elem._lower, elem._upper, @@ -229,9 +231,7 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): empty=not elem._bounds, ) for elem in value - ] - - return value + ) return to_range @@ -288,7 +288,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): sqltypes.Integer: _PGInteger, sqltypes.SmallInteger: _PGSmallInteger, sqltypes.BigInteger: _PGBigInteger, - ranges.AbstractRange: _PsycopgRange, + ranges.AbstractSingleRange: _PsycopgRange, ranges.AbstractMultiRange: _PsycopgMultiRange, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 0b89149ec9..9bf2e49336 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -513,7 +513,7 @@ class _PGJSONB(JSONB): return None -class _Psycopg2Range(ranges.AbstractRangeImpl): +class _Psycopg2Range(ranges.AbstractSingleRangeImpl): _psycopg2_range_cls = "none" def bind_processor(self, dialect): diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 980f144935..b793ca49f1 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -15,8 +15,10 @@ from decimal import Decimal from typing import Any from typing import cast from typing import Generic +from typing import List from typing import Optional from typing import overload +from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -152,8 +154,8 @@ class Range(Generic[_T]): return not self.empty and self.upper is None @property - def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: - return AbstractRange() + def __sa_type_engine__(self) -> AbstractSingleRange[_T]: + return AbstractSingleRange() def _contains_value(self, value: _T) -> bool: """Return True if this range contains the given value.""" @@ -708,15 +710,34 @@ class Range(Generic[_T]): return f"{b0}{l},{r}{b1}" -class AbstractRange(sqltypes.TypeEngine[Range[_T]]): - """ - Base for PostgreSQL RANGE types. +class MultiRange(List[Range[_T]]): + """Represents a multirange sequence. + + This list subclass is an utility to allow automatic type inference of + the proper multi-range SQL type depending on the single range values. + This is useful when operating on literal multi-ranges:: + + import sqlalchemy as sa + from sqlalchemy.dialects.postgresql import MultiRange, Range + + value = literal(MultiRange([Range(2, 4)])) + + select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)]))) + + .. versionadded:: 2.0.26 .. seealso:: - `PostgreSQL range functions `_ + - :ref:`postgresql_multirange_list_use`. + """ - """ # noqa: E501 + @property + def __sa_type_engine__(self) -> AbstractMultiRange[_T]: + return AbstractMultiRange() + + +class AbstractRange(sqltypes.TypeEngine[_T]): + """Base class for single and multi Range SQL types.""" render_bind_cast = True @@ -742,7 +763,10 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): and also render as ``INT4RANGE`` in SQL and DDL. """ - if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__: + if ( + issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl)) + and cls is not self.__class__ + ): # two ways to do this are: 1. create a new type on the fly # or 2. have AbstractRangeImpl(visit_name) constructor and a # visit_abstract_range_impl() method in the PG compiler. @@ -761,21 +785,6 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): else: return super().adapt(cls) - def _resolve_for_literal(self, value: Any) -> Any: - spec = value.lower if value.lower is not None else value.upper - - if isinstance(spec, int): - return INT8RANGE() - elif isinstance(spec, (Decimal, float)): - return NUMRANGE() - elif isinstance(spec, datetime): - return TSRANGE() if not spec.tzinfo else TSTZRANGE() - elif isinstance(spec, date): - return DATERANGE() - else: - # empty Range, SQL datatype can't be determined here - return sqltypes.NULLTYPE - class comparator_factory(TypeEngine.Comparator[Range[Any]]): """Define comparison operations for range types.""" @@ -857,91 +866,164 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): return self.expr.operate(operators.mul, other) -class AbstractRangeImpl(AbstractRange[Range[_T]]): - """Marker for AbstractRange that will apply a subclass-specific +class AbstractSingleRange(AbstractRange[Range[_T]]): + """Base for PostgreSQL RANGE types. + + These are types that return a single :class:`_postgresql.Range` object. + + .. seealso:: + + `PostgreSQL range functions `_ + + """ # noqa: E501 + + __abstract__ = True + + def _resolve_for_literal(self, value: Range[Any]) -> Any: + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + # pg is unreasonably picky here: the query + # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises + # "operator does not exist: integer <@ int8range" as of pg 16 + if _is_int32(value): + return INT4RANGE() + else: + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + + +class AbstractSingleRangeImpl(AbstractSingleRange[_T]): + """Marker for AbstractSingleRange that will apply a subclass-specific adaptation""" -class AbstractMultiRange(AbstractRange[Range[_T]]): - """base for PostgreSQL MULTIRANGE types""" +class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]): + """Base for PostgreSQL MULTIRANGE types. + + these are types that return a sequence of :class:`_postgresql.Range` + objects. + + """ __abstract__ = True + def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any: + if not value: + # empty MultiRange, SQL datatype can't be determined here + return sqltypes.NULLTYPE + first = value[0] + spec = first.lower if first.lower is not None else first.upper -class AbstractMultiRangeImpl( - AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]] -): - """Marker for AbstractRange that will apply a subclass-specific + if isinstance(spec, int): + # pg is unreasonably picky here: the query + # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises + # "operator does not exist: integer <@ int8multirange" as of pg 16 + if all(_is_int32(r) for r in value): + return INT4MULTIRANGE() + else: + return INT8MULTIRANGE() + elif isinstance(spec, (Decimal, float)): + return NUMMULTIRANGE() + elif isinstance(spec, datetime): + return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE() + elif isinstance(spec, date): + return DATEMULTIRANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + + +class AbstractMultiRangeImpl(AbstractMultiRange[_T]): + """Marker for AbstractMultiRange that will apply a subclass-specific adaptation""" -class INT4RANGE(AbstractRange[Range[int]]): +class INT4RANGE(AbstractSingleRange[int]): """Represent the PostgreSQL INT4RANGE type.""" __visit_name__ = "INT4RANGE" -class INT8RANGE(AbstractRange[Range[int]]): +class INT8RANGE(AbstractSingleRange[int]): """Represent the PostgreSQL INT8RANGE type.""" __visit_name__ = "INT8RANGE" -class NUMRANGE(AbstractRange[Range[Decimal]]): +class NUMRANGE(AbstractSingleRange[Decimal]): """Represent the PostgreSQL NUMRANGE type.""" __visit_name__ = "NUMRANGE" -class DATERANGE(AbstractRange[Range[date]]): +class DATERANGE(AbstractSingleRange[date]): """Represent the PostgreSQL DATERANGE type.""" __visit_name__ = "DATERANGE" -class TSRANGE(AbstractRange[Range[datetime]]): +class TSRANGE(AbstractSingleRange[datetime]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSRANGE" -class TSTZRANGE(AbstractRange[Range[datetime]]): +class TSTZRANGE(AbstractSingleRange[datetime]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZRANGE" -class INT4MULTIRANGE(AbstractMultiRange[Range[int]]): +class INT4MULTIRANGE(AbstractMultiRange[int]): """Represent the PostgreSQL INT4MULTIRANGE type.""" __visit_name__ = "INT4MULTIRANGE" -class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): +class INT8MULTIRANGE(AbstractMultiRange[int]): """Represent the PostgreSQL INT8MULTIRANGE type.""" __visit_name__ = "INT8MULTIRANGE" -class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): +class NUMMULTIRANGE(AbstractMultiRange[Decimal]): """Represent the PostgreSQL NUMMULTIRANGE type.""" __visit_name__ = "NUMMULTIRANGE" -class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): +class DATEMULTIRANGE(AbstractMultiRange[date]): """Represent the PostgreSQL DATEMULTIRANGE type.""" __visit_name__ = "DATEMULTIRANGE" -class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSMULTIRANGE" -class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSTZMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZMULTIRANGE" + + +_max_int_32 = 2**31 - 1 +_min_int_32 = -(2**31) + + +def _is_int32(r: Range[int]) -> bool: + return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and ( + r.upper is None or _min_int_32 <= r.upper <= _max_int_32 + ) diff --git a/setup.cfg b/setup.cfg index d51e4d854c..2a8a68132a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -181,7 +181,7 @@ mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes mssql_async = mssql+aioodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes pymssql = mssql+pymssql://scott:tiger^5HHH@mssql2017:1433/test -docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server +docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe oracledb = oracle+oracledb://scott:tiger@oracle18c/xe diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 005e60eaa1..10144d63a6 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -52,6 +52,7 @@ from sqlalchemy.dialects.postgresql import TSQUERY from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.dialects.postgresql.ranges import MultiRange from sqlalchemy.orm import aliased from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import Session @@ -2588,7 +2589,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(expr, expected) - def test_custom_object_hook(self): + def test_range_custom_object_hook(self): # See issue #8884 from datetime import date @@ -2608,6 +2609,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "WHERE usages.date <@ %(date_1)s::DATERANGE", ) + def test_multirange_custom_object_hook(self): + from datetime import date + + usages = table( + "usages", + column("id", Integer), + column("date", Date), + column("amount", Integer), + ) + period = MultiRange( + [ + Range(date(2022, 1, 1), (2023, 1, 1)), + Range(date(2024, 1, 1), (2025, 1, 1)), + ] + ) + stmt = select(func.sum(usages.c.amount)).where( + usages.c.date.op("<@")(period) + ) + self.assert_compile( + stmt, + "SELECT sum(usages.amount) AS sum_1 FROM usages " + "WHERE usages.date <@ %(date_1)s::DATEMULTIRANGE", + ) + def test_bitwise_xor(self): c1 = column("c1", Integer) c2 = column("c2", Integer) diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 2088436eeb..a5093c0bc9 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -73,6 +73,7 @@ from sqlalchemy.dialects.postgresql import TSMULTIRANGE from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE +from sqlalchemy.dialects.postgresql.ranges import MultiRange from sqlalchemy.exc import CompileError from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import declarative_base @@ -92,6 +93,7 @@ from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import ComparesTables from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.suite import test_types as suite @@ -3887,6 +3889,53 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(s.query(Data.data, Data).all(), [(d.data, d)]) +class RangeMiscTests(fixtures.TestBase): + @testing.combinations( + (Range(2, 7), INT4RANGE), + (Range(-10, 7), INT4RANGE), + (Range(None, -7), INT4RANGE), + (Range(33, None), INT4RANGE), + (Range(-2147483648, 2147483647), INT4RANGE), + (Range(-2147483648 - 1, 2147483647), INT8RANGE), + (Range(-2147483648, 2147483647 + 1), INT8RANGE), + (Range(-2147483648 - 1, None), INT8RANGE), + (Range(None, 2147483647 + 1), INT8RANGE), + ) + def test_resolve_for_literal(self, obj, type_): + """This tests that the int4 / int8 version is selected correctly by + _resolve_for_literal.""" + lit = literal(obj) + eq_(type(lit.type), type_) + + @testing.combinations( + (Range(2, 7), INT4MULTIRANGE), + (Range(-10, 7), INT4MULTIRANGE), + (Range(None, -7), INT4MULTIRANGE), + (Range(33, None), INT4MULTIRANGE), + (Range(-2147483648, 2147483647), INT4MULTIRANGE), + (Range(-2147483648 - 1, 2147483647), INT8MULTIRANGE), + (Range(-2147483648, 2147483647 + 1), INT8MULTIRANGE), + (Range(-2147483648 - 1, None), INT8MULTIRANGE), + (Range(None, 2147483647 + 1), INT8MULTIRANGE), + ) + def test_resolve_for_literal_multi(self, obj, type_): + """This tests that the int4 / int8 version is selected correctly by + _resolve_for_literal.""" + list_ = MultiRange([Range(-1, 1), obj, Range(7, 100)]) + lit = literal(list_) + eq_(type(lit.type), type_) + + def test_multirange_sequence(self): + plain = [Range(-1, 1), Range(42, 43), Range(7, 100)] + mr = MultiRange(plain) + is_true(issubclass(MultiRange, list)) + is_true(isinstance(mr, list)) + eq_(mr, plain) + eq_(str(mr), str(plain)) + eq_(repr(mr), repr(plain)) + ne_(mr, plain[1:]) + + class _RangeTests: _col_type = None "The concrete range class these tests are for." @@ -4641,11 +4690,21 @@ class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): Brought up in #8540. """ + # see also CompileTest::test_range_custom_object_hook data_obj = self._data_obj() stmt = select(literal(data_obj, type_=self._col_type)) round_trip = connection.scalar(stmt) eq_(round_trip, data_obj) + def test_auto_cast_back_to_type_without_type(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_range_custom_object_hook + data_obj = self._data_obj() + lit = literal(data_obj) + round_trip = connection.scalar(select(lit)) + eq_(round_trip, data_obj) + eq_(type(lit.type), self._col_type) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -5140,10 +5199,17 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): ) -class _MultiRangeTypeRoundTrip(fixtures.TablesTest): +class _MultiRangeTypeRoundTrip(fixtures.TablesTest, _RangeTests): __requires__ = ("multirange_types",) __backend__ = True + @testing.fixture(params=(True, False), ids=["multirange", "plain_list"]) + def data_obj(self, request): + if request.param: + return MultiRange(self._data_obj()) + else: + return list(self._data_obj()) + @classmethod def define_tables(cls, metadata): # no reason ranges shouldn't be primary keys, @@ -5155,7 +5221,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): ) cls.col = table.c.range - def test_auto_cast_back_to_type(self, connection): + def test_auto_cast_back_to_type(self, connection, data_obj): """test that a straight pass of the range type without any context will send appropriate casting info so that the driver can round trip it. @@ -5170,11 +5236,29 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): Brought up in #8540. """ - data_obj = self._data_obj() + # see also CompileTest::test_multirange_custom_object_hook stmt = select(literal(data_obj, type_=self._col_type)) round_trip = connection.scalar(stmt) eq_(round_trip, data_obj) + def test_auto_cast_back_to_type_without_type(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_multirange_custom_object_hook + data_obj = MultiRange(self._data_obj()) + lit = literal(data_obj) + round_trip = connection.scalar(select(lit)) + eq_(round_trip, data_obj) + eq_(type(lit.type), self._col_type) + + @testing.fails("no automatic adaptation of plain list") + def test_auto_cast_back_to_type_without_type_plain_list(self, connection): + """use _resolve_for_literal to cast""" + # see also CompileTest::test_multirange_custom_object_hook + data_obj = list(self._data_obj()) + lit = literal(data_obj) + r = connection.scalar(select(lit)) + eq_(type(r), list) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -5188,12 +5272,12 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): def _assert_data(self, conn): data = conn.execute(select(self.tables.data_table.c.range)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) - def test_textual_round_trip_w_dialect_type(self, connection): + def test_textual_round_trip_w_dialect_type(self, connection, data_obj): """test #8690""" data_table = self.tables.data_table - data_obj = self._data_obj() connection.execute( self.tables.data_table.insert(), {"range": data_obj} ) @@ -5206,9 +5290,9 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): eq_(data_obj, v2) - def test_insert_obj(self, connection): + def test_insert_obj(self, connection, data_obj): connection.execute( - self.tables.data_table.insert(), {"range": self._data_obj()} + self.tables.data_table.insert(), {"range": data_obj} ) self._assert_data(connection) @@ -5229,6 +5313,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): range_ = self.tables.data_table.c.range data = connection.execute(select(range_ + range_)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) @testing.requires.psycopg_or_pg8000_compatibility def test_intersection_result_text(self, connection): @@ -5240,6 +5325,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): range_ = self.tables.data_table.c.range data = connection.execute(select(range_ * range_)).fetchall() eq_(data, [(self._data_obj(),)]) + eq_(type(data[0][0]), MultiRange) @testing.requires.psycopg_or_pg8000_compatibility def test_difference_result_text(self, connection): @@ -5251,6 +5337,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): range_ = self.tables.data_table.c.range data = connection.execute(select(range_ - range_)).fetchall() eq_(data, [([],)]) + eq_(type(data[0][0]), MultiRange) class _Int4MultiRangeTests: @@ -5261,11 +5348,7 @@ class _Int4MultiRangeTests: return "{[1,2), [3, 5), [9, 12)}" def _data_obj(self): - return [ - Range(1, 2), - Range(3, 5), - Range(9, 12), - ] + return [Range(1, 2), Range(3, 5), Range(9, 12)] class _Int8MultiRangeTests: @@ -5465,6 +5548,17 @@ class DateTimeTZRMultiangeRoundTripTest( pass +class MultiRangeSequenceTest(fixtures.TestBase): + def test_methods(self): + plain = [Range(1, 3), Range(5, 9)] + multi = MultiRange(plain) + is_true(isinstance(multi, list)) + eq_(multi, plain) + ne_(multi, plain[:1]) + eq_(str(multi), str(plain)) + eq_(repr(multi), repr(plain)) + + class JSONTest(AssertsCompiledSQL, fixtures.TestBase): __dialect__ = "postgresql" diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 4567daa386..678d22b71f 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -12,14 +12,17 @@ from sqlalchemy import Text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import INT4RANGE +from sqlalchemy.dialects.postgresql import INT8MULTIRANGE from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column - # test #6402 c1 = Column(UUID()) @@ -77,3 +80,19 @@ insert(Test).on_conflict_do_nothing( ).on_conflict_do_update( unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 ).excluded.foo.desc() + + +# EXPECTED_TYPE: Column[Range[int]] +reveal_type(Column(INT4RANGE())) +# EXPECTED_TYPE: Column[Range[datetime.date]] +reveal_type(Column("foo", DATERANGE())) +# EXPECTED_TYPE: Column[Sequence[Range[int]]] +reveal_type(Column(INT8MULTIRANGE())) +# EXPECTED_TYPE: Column[Sequence[Range[datetime.datetime]]] +reveal_type(Column("foo", TSTZMULTIRANGE())) + + +range_col_stmt = select(Column(INT4RANGE()), Column(INT8MULTIRANGE())) + +# EXPECTED_TYPE: Select[Tuple[Range[int], Sequence[Range[int]]]] +reveal_type(range_col_stmt) -- 2.47.2