]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement PG ranges/multiranges agnostically
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Aug 2022 14:27:59 +0000 (10:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Aug 2022 14:39:39 +0000 (10:39 -0400)
Ranges now work using a new Range object,
multiranges as lists of Range objects (this is what
asyncpg does.  not sure why psycopg has a "Multirange"
type).

psycopg, psycopg2, and asyncpg are currently supported.
It's not clear how to make ranges work with pg8000, likely
needs string conversion; this is straightforward with the
new archicture and can be added later.

Fixes: #8178
Change-Id: Iab8d8382873d5c14199adbe3f09fd0dc17e2b9f1

14 files changed:
doc/build/changelog/unreleased_20/7156.rst
doc/build/conf.py
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/ranges.py
lib/sqlalchemy/util/__init__.py
test/dialect/postgresql/test_dialect.py
test/dialect/postgresql/test_types.py
test/requirements.py

index 76a27ccc1fe0d14cbc1e5587e5e4cd40a31bc2f4..e54d7168998ee152252a1d0c13829ce87b1ac14c 100644 (file)
@@ -3,5 +3,15 @@
     :tickets: 7156
 
     Adds support for PostgreSQL MultiRange types, introduced in PostgreSQL 14.
-    Note that this feature currently only tested with `psycopg` and depends on
-    the `psycopg.types.range` extension module.
\ No newline at end of file
+    This integrates with both the ``psycopg.types.range`` extension module
+    as well as new the :class:`_postgresql.MultiRange` datatype introduced
+    in SQLAlchemy 2.0 for all PostgreSQL backends.
+
+
+.. change::
+    :tags: postgresql, feature
+
+    Implemented support for PostgreSQL ranges and multiranges for all
+    PostgreSQL backends.  To establish range and multirange types, use the
+    new :class:`_postgresql.Range` and :class:`_postgresql.MultiRange`
+    datatypes.
\ No newline at end of file
index 951e843b6896726df5656465a753ed7820a8dcd2..628dec47e47379290bd1350b9eb70f9d1a738768 100644 (file)
@@ -115,6 +115,7 @@ autodoc_class_signature = "separated"
 
 autodoc_default_options = {
     "exclude-members": "__new__",
+    "undoc-members": False,
 }
 
 # enable "annotation" indicator.  doesn't actually use this
index b3755c2cde07cdec4045b7e7be451e2c61184b85..a20c515e9581e959a2941fa977abf320323b81c1 100644 (file)
@@ -5,8 +5,8 @@ PostgreSQL
 
 .. automodule:: sqlalchemy.dialects.postgresql.base
 
-PostgreSQL Data Types and Custom SQL Constructs
-------------------------------------------------
+PostgreSQL Data Types
+---------------------
 
 As with all SQLAlchemy dialects, all UPPERCASE types that are known to be
 valid with PostgreSQL are importable from the top level dialect, whether
@@ -105,12 +105,6 @@ construction arguments, are as follows:
     :noindex:
 
 
-Range Types
-~~~~~~~~~~~
-
-The new range column types found in PostgreSQL 9.2 onwards are
-catered for by the following types:
-
 .. autoclass:: INT4RANGE
 
 
@@ -129,53 +123,6 @@ catered for by the following types:
 .. autoclass:: TSTZRANGE
 
 
-The types above get most of their functionality from the following
-mixin:
-
-.. autoclass:: sqlalchemy.dialects.postgresql.ranges.RangeOperators
-    :members:
-
-.. warning::
-
-  The range type DDL support should work with any PostgreSQL DBAPI
-  driver, however the data types returned may vary. If you are using
-  ``psycopg2``, it's recommended to upgrade to version 2.5 or later
-  before using these column types.
-
-When instantiating models that use these column types, you should pass
-whatever data type is expected by the DBAPI driver you're using for
-the column type. For ``psycopg2`` these are
-``psycopg2.extras.NumericRange``,
-``psycopg2.extras.DateRange``,
-``psycopg2.extras.DateTimeRange`` and
-``psycopg2.extras.DateTimeTZRange`` or the class you've
-registered with ``psycopg2.extras.register_range``.
-
-For example:
-
-.. code-block:: python
-
-  from psycopg2.extras import DateTimeRange
-  from sqlalchemy.dialects.postgresql import TSRANGE
-
-  class RoomBooking(Base):
-
-      __tablename__ = 'room_booking'
-
-      room = Column(Integer(), primary_key=True)
-      during = Column(TSRANGE())
-
-  booking = RoomBooking(
-      room=101,
-      during=DateTimeRange(datetime(2013, 3, 23), None)
-  )
-
-MultiRange Types
-~~~~~~~~~~~~~~~~
-
-The new MultiRange column types found in PostgreSQL 14 onwards are
-catered for by the following types:
-
 .. autoclass:: INT4MULTIRANGE
 
 
@@ -194,47 +141,6 @@ catered for by the following types:
 .. autoclass:: TSTZMULTIRANGE
 
 
-The types above get most of their functionality from the following
-mixin:
-
-.. autoclass:: sqlalchemy.dialects.postgresql.ranges.RangeOperators
-    :members:
-
-.. warning::
-
-  The multirange type DDL support should work with any PostgreSQL DBAPI
-  driver, however the data types returned may vary. The feature is
-  currently developed against the psycopg driver, and is known to
-  work with the range types specific to the `psycopg.types.range`
-  extension module.
-
-When instantiating models that use these column types, you should pass
-whatever data type is expected by the DBAPI driver you're using for
-the column type.
-
-For example:
-
-.. code-block:: python
-  # Note: Multirange type currently only tested against the psycopg
-  # driver, hence the use here.
-  from psycopg.types.range import Range
-  from pscyopg.types.multirange import Multirange
-  from sqlalchemy.dialects.postgresql import TSMULTIRANGE
-
-  class RoomBooking(Base):
-
-      __tablename__ = 'room_booking'
-
-      room = Column(Integer(), primary_key=True)
-      during = Column(TSMULTIRANGE())
-
-  booking = RoomBooking(
-      room=101,
-      during=Multirange([
-          Range(datetime(2013, 3, 23), datetime(2014, 3, 22)),
-          Range(datetime(2015, 1, 1), None)
-      ])
-
 
 PostgreSQL Constraint Types
 ---------------------------
index baafdb1811f2665110eec915017df76455016b67..104077a171c70f6d4a48fdd0d8bedc851a376ca7 100644 (file)
@@ -55,6 +55,7 @@ from .ranges import INT8MULTIRANGE
 from .ranges import INT8RANGE
 from .ranges import NUMMULTIRANGE
 from .ranges import NUMRANGE
+from .ranges import Range
 from .ranges import TSMULTIRANGE
 from .ranges import TSRANGE
 from .ranges import TSTZMULTIRANGE
@@ -135,6 +136,7 @@ __all__ = (
     "NamedType",
     "CreateEnumType",
     "ExcludeConstraint",
+    "Range",
     "aggregate_order_by",
     "array_agg",
     "insert",
index efd1dbe414ce7c5be792100fe7932acd891eec03..92341d2dac5ce67c86799b4693c7e92f42354dfc 100644 (file)
@@ -4,6 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
+from __future__ import annotations
 
 import decimal
 
index d6385a5d6105a124d204af308d11f329dbe73cfc..38f8fddee66dba62e60ba58117d65b6c378f09a4 100644 (file)
@@ -119,14 +119,19 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
 
 """  # noqa
 
+from __future__ import annotations
+
 import collections
 import collections.abc as collections_abc
 import decimal
 import json as _py_json
 import re
 import time
+from typing import cast
+from typing import TYPE_CHECKING
 
 from . import json
+from . import ranges
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -148,6 +153,9 @@ from ...util.concurrency import asyncio
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
+if TYPE_CHECKING:
+    from typing import Iterable
+
 
 class AsyncpgString(sqltypes.String):
     render_bind_cast = True
@@ -278,6 +286,91 @@ class AsyncpgCHAR(sqltypes.CHAR):
     render_bind_cast = True
 
 
+class _AsyncpgRange(ranges.AbstractRange):
+    def bind_processor(self, dialect):
+        Range = dialect.dbapi.asyncpg.Range
+
+        NoneType = type(None)
+
+        def to_range(value):
+            if not isinstance(value, (str, NoneType)):
+                value = Range(
+                    value.lower,
+                    value.upper,
+                    lower_inc=value.bounds[0] == "[",
+                    upper_inc=value.bounds[1] == "]",
+                    empty=value.empty,
+                )
+            return value
+
+        return to_range
+
+    def result_processor(self, dialect, coltype):
+        def to_range(value):
+            if value is not None:
+                empty = value.isempty
+                value = ranges.Range(
+                    value.lower,
+                    value.upper,
+                    bounds=f"{'[' if empty or value.lower_inc else '('}"  # type: ignore  # noqa: E501
+                    f"{']' if not empty and value.upper_inc else ')'}",
+                    empty=empty,
+                )
+            return value
+
+        return to_range
+
+
+class _AsyncpgMultiRange(ranges.AbstractMultiRange):
+    def bind_processor(self, dialect):
+        Range = dialect.dbapi.asyncpg.Range
+
+        NoneType = type(None)
+
+        def to_range(value):
+            if isinstance(value, (str, NoneType)):
+                return value
+
+            def to_range(value):
+                if not isinstance(value, (str, NoneType)):
+                    value = Range(
+                        value.lower,
+                        value.upper,
+                        lower_inc=value.bounds[0] == "[",
+                        upper_inc=value.bounds[1] == "]",
+                        empty=value.empty,
+                    )
+                return value
+
+            return [
+                to_range(element)
+                for element in cast("Iterable[ranges.Range]", value)
+            ]
+
+        return to_range
+
+    def result_processor(self, dialect, coltype):
+        def to_range_array(value):
+            def to_range(rvalue):
+                if rvalue is not None:
+                    empty = rvalue.isempty
+                    rvalue = ranges.Range(
+                        rvalue.lower,
+                        rvalue.upper,
+                        bounds=f"{'[' if empty or rvalue.lower_inc else '('}"  # type: ignore  # noqa: E501
+                        f"{']' if not empty and rvalue.upper_inc else ')'}",
+                        empty=empty,
+                    )
+                return rvalue
+
+            if value is not None:
+                value = [to_range(elem) for elem in value]
+
+            return value
+
+        return to_range_array
+
+
 class PGExecutionContext_asyncpg(PGExecutionContext):
     def handle_dbapi_exception(self, e):
         if isinstance(
@@ -828,6 +921,8 @@ class PGDialect_asyncpg(PGDialect):
             OID: AsyncpgOID,
             REGCLASS: AsyncpgREGCLASS,
             sqltypes.CHAR: AsyncpgCHAR,
+            ranges.AbstractRange: _AsyncpgRange,
+            ranges.AbstractMultiRange: _AsyncpgMultiRange,
         },
     )
     is_async = True
index efb4dd547f0bbc67589afd926cc491d4f16a67d9..2ee679e8eaee973ed0a24e1244996e11e6a94b9a 100644 (file)
@@ -1445,6 +1445,157 @@ E.g.::
         Column('data', CastingArray(JSONB))
     )
 
+Range and Multirange Types
+--------------------------
+
+PostgreSQL range and multirange types are supported for the psycopg2,
+psycopg, and asyncpg dialects.
+
+Data values being passed to the database may be passed as string
+values or by using the :class:`_postgresql.Range` data object.
+
+.. versionadded:: 2.0  Added the backend-agnostic :class:`_postgresql.Range`
+   object used to indicate ranges.  The ``psycopg2``-specific range classes
+   are no longer exposed and are only used internally by that particular
+   dialect.
+
+E.g. an example of a fully typed model using the
+:class:`_postgresql.TSRANGE` datatype::
+
+  from datetime import datetime
+
+  from sqlalchemy.dialects.postgresql import Range
+  from sqlalchemy.dialects.postgresql import TSRANGE
+  from sqlalchemy.orm import DeclarativeBase
+  from sqlalchemy.orm import Mapped
+  from sqlalchemy.orm import mapped_column
+
+  class Base(DeclarativeBase):
+      pass
+
+  class RoomBooking(Base):
+
+      __tablename__ = 'room_booking'
+
+      id: Mapped[int] = mapped_column(primary_key=True)
+      room: Mapped[str]
+      during: Mapped[Range[datetime]] = mapped_column(TSRANGE)
+
+To represent data for the ``during`` column above, the :class:`_postgresql.Range`
+type is a simple dataclass that will represent the bounds of the range.
+Below illustrates an INSERT of a row into the above ``room_booking`` table::
+
+  from sqlalchemy import create_engine
+  from sqlalchemy.orm import Session
+
+  engine = create_engine("postgresql+psycopg://scott:tiger@pg14/dbname")
+
+  Base.metadata.create_all(engine)
+
+  with Session(engine) as session:
+      booking = RoomBooking(
+          room="101",
+          during=Range(datetime(2013, 3, 23), datetime(2013, 3, 25))
+      )
+      session.add(booking)
+      session.commit()
+
+Selecting from any range column will also return :class:`_postgresql.Range`
+objects as indicated::
+
+  from sqlalchemy import select
+
+  with Session(engine) as session:
+      for row in session.execute(select(RoomBooking.during)):
+          print(row)
+
+The available range datatypes are as follows:
+
+* :class:`_postgresql.INT4RANGE`
+* :class:`_postgresql.INT8RANGE`
+* :class:`_postgresql.NUMRANGE`
+* :class:`_postgresql.DATERANGE`
+* :class:`_postgresql.TSRANGE`
+* :class:`_postgresql.TSTZRANGE`
+
+.. autoclass:: sqlalchemy.dialects.postgresql.Range
+
+Multiranges
+^^^^^^^^^^^
+
+Multiranges are supported by PostgreSQL 14 and above.  SQLAlchemy's
+multirange datatypes deal in lists of :class:`_postgresql.Range` types.
+
+.. versionadded:: 2.0 Added support for MULTIRANGE datatypes.   In contrast
+   to the ``psycopg`` multirange feature, SQLAlchemy's adaptation represents
+   a multirange datatype as a list of :class:`_postgresql.Range` objects.
+
+The example below illustrates use of the :class:`_postgresql.TSMULTIRANGE`
+datatype::
+
+    from datetime import datetime
+    from typing import List
+
+    from sqlalchemy.dialects.postgresql import Range
+    from sqlalchemy.dialects.postgresql import TSMULTIRANGE
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+
+    class Base(DeclarativeBase):
+        pass
+
+    class EventCalendar(Base):
+
+        __tablename__ = 'event_calendar'
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        event_name: Mapped[str]
+        in_session_periods: Mapped[List[Range[datetime]]] = mapped_column(TSMULTIRANGE)
+
+Illustrating insertion and selecting of a record::
+
+    from sqlalchemy import create_engine
+    from sqlalchemy import select
+    from sqlalchemy.orm import Session
+
+    engine = create_engine("postgresql+psycopg://scott:tiger@pg14/test")
+
+    Base.metadata.create_all(engine)
+
+    with Session(engine) as session:
+        calendar = EventCalendar(
+            event_name="SQLAlchemy Tutorial Sessions",
+            in_session_periods= [
+                Range(datetime(2013, 3, 23), datetime(2013, 3, 25)),
+                Range(datetime(2013, 4, 12), datetime(2013, 4, 15)),
+                Range(datetime(2013, 5, 9), datetime(2013, 5, 12)),
+            ]
+        )
+        session.add(calendar)
+        session.commit()
+
+        for multirange in session.scalars(select(EventCalendar.in_session_periods)):
+            for range_ in multirange:
+                print(f"Start: {range_.lower}  End: {range_.upper}")
+
+.. note:: In the above example, the list of :class:`_postgresql.Range` types
+   as handled by the ORM will not automatically detect in-place changes to
+   a particular list value; to update list values with the ORM, either re-assign
+   a new list to the attribute, or use the :class:`.MutableList`
+   type modifier.  See the section :ref:`mutable_toplevel` for background.
+
+
+The available multirange datatypes are as follows:
+
+* :class:`_postgresql.INT4MULTIRANGE`
+* :class:`_postgresql.INT8MULTIRANGE`
+* :class:`_postgresql.NUMMULTIRANGE`
+* :class:`_postgresql.DATEMULTIRANGE`
+* :class:`_postgresql.TSMULTIRANGE`
+* :class:`_postgresql.TSTZMULTIRANGE`
+
+
 
 """  # noqa: E501
 
index 414976a6299347e57d50510ab274739666c205e0..633357a74062be377adb09ea9e039b410b74528c 100644 (file)
@@ -57,9 +57,14 @@ release of SQLAlchemy 2.0, however.
     Further documentation is available there.
 
 """  # noqa
+from __future__ import annotations
+
 import logging
 import re
+from typing import cast
+from typing import TYPE_CHECKING
 
+from . import ranges
 from ._psycopg_common import _PGDialect_common_psycopg
 from ._psycopg_common import _PGExecutionContext_common_psycopg
 from .base import INTERVAL
@@ -75,6 +80,9 @@ from ...sql import sqltypes
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
+if TYPE_CHECKING:
+    from typing import Iterable
+
 logger = logging.getLogger("sqlalchemy.dialects.postgresql")
 
 
@@ -154,6 +162,78 @@ class _PGBoolean(sqltypes.Boolean):
     render_bind_cast = True
 
 
+class _PsycopgRange(ranges.AbstractRange):
+    def bind_processor(self, dialect):
+        Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
+
+        NoneType = type(None)
+
+        def to_range(value):
+            if not isinstance(value, (str, NoneType)):
+                value = Range(
+                    value.lower, value.upper, value.bounds, value.empty
+                )
+            return value
+
+        return to_range
+
+    def result_processor(self, dialect, coltype):
+        def to_range(value):
+            if value is not None:
+                value = ranges.Range(
+                    value._lower,
+                    value._upper,
+                    bounds=value._bounds if value._bounds else "[)",
+                    empty=not value._bounds,
+                )
+            return value
+
+        return to_range
+
+
+class _PsycopgMultiRange(ranges.AbstractMultiRange):
+    def bind_processor(self, dialect):
+        Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
+        Multirange = cast(PGDialect_psycopg, dialect)._psycopg_Multirange
+
+        NoneType = type(None)
+
+        def to_range(value):
+            if isinstance(value, (str, NoneType)):
+                return value
+
+            return Multirange(
+                [
+                    Range(
+                        element.lower,
+                        element.upper,
+                        element.bounds,
+                        element.empty,
+                    )
+                    for element in cast("Iterable[ranges.Range]", value)
+                ]
+            )
+
+        return to_range
+
+    def result_processor(self, dialect, coltype):
+        def to_range(value):
+            if value is not None:
+                value = [
+                    ranges.Range(
+                        elem._lower,
+                        elem._upper,
+                        bounds=elem._bounds if elem._bounds else "[)",
+                        empty=not elem._bounds,
+                    )
+                    for elem in value
+                ]
+
+            return value
+
+        return to_range
+
+
 class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
     pass
 
@@ -204,6 +284,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
             sqltypes.Integer: _PGInteger,
             sqltypes.SmallInteger: _PGSmallInteger,
             sqltypes.BigInteger: _PGBigInteger,
+            ranges.AbstractRange: _PsycopgRange,
+            ranges.AbstractMultiRange: _PsycopgMultiRange,
         },
     )
 
@@ -314,6 +396,18 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
 
         return TransactionStatus
 
+    @util.memoized_property
+    def _psycopg_Range(self):
+        from psycopg.types.range import Range
+
+        return Range
+
+    @util.memoized_property
+    def _psycopg_Multirange(self):
+        from psycopg.types.multirange import Multirange
+
+        return Multirange
+
     def _do_isolation_level(self, connection, autocommit, isolation_level):
         connection.autocommit = autocommit
         connection.isolation_level = isolation_level
index 6f78dafdd0cf239d13e852703e7f938b1c65c4bf..5dcd449cab8cbcbe61529f8b68cd7b4c0527f028 100644 (file)
@@ -474,10 +474,14 @@ place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
 which may be more performant.
 
 """  # noqa
+from __future__ import annotations
+
 import collections.abc as collections_abc
 import logging
 import re
+from typing import cast
 
+from . import ranges
 from ._psycopg_common import _PGDialect_common_psycopg
 from ._psycopg_common import _PGExecutionContext_common_psycopg
 from .base import PGCompiler
@@ -490,7 +494,6 @@ from ...engine import cursor as _cursor
 from ...util import FastIntFlag
 from ...util import parse_user_argument_for_enum
 
-
 logger = logging.getLogger("sqlalchemy.dialects.postgresql")
 
 
@@ -504,6 +507,56 @@ class _PGJSONB(JSONB):
         return None
 
 
+class _Psycopg2Range(ranges.AbstractRange):
+    _psycopg2_range_cls = "none"
+
+    def bind_processor(self, dialect):
+        Range = getattr(
+            cast(PGDialect_psycopg2, dialect)._psycopg2_extras,
+            self._psycopg2_range_cls,
+        )
+
+        NoneType = type(None)
+
+        def to_range(value):
+            if not isinstance(value, (str, NoneType)):
+                value = Range(
+                    value.lower, value.upper, value.bounds, value.empty
+                )
+            return value
+
+        return to_range
+
+    def result_processor(self, dialect, coltype):
+        def to_range(value):
+            if value is not None:
+                value = ranges.Range(
+                    value._lower,
+                    value._upper,
+                    bounds=value._bounds if value._bounds else "[)",
+                    empty=not value._bounds,
+                )
+            return value
+
+        return to_range
+
+
+class _Psycopg2NumericRange(_Psycopg2Range):
+    _psycopg2_range_cls = "NumericRange"
+
+
+class _Psycopg2DateRange(_Psycopg2Range):
+    _psycopg2_range_cls = "DateRange"
+
+
+class _Psycopg2DateTimeRange(_Psycopg2Range):
+    _psycopg2_range_cls = "DateTimeRange"
+
+
+class _Psycopg2DateTimeTZRange(_Psycopg2Range):
+    _psycopg2_range_cls = "DateTimeTZRange"
+
+
 class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
     _psycopg2_fetched_rows = None
 
@@ -589,6 +642,12 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg):
             JSON: _PGJSON,
             sqltypes.JSON: _PGJSON,
             JSONB: _PGJSONB,
+            ranges.INT4RANGE: _Psycopg2NumericRange,
+            ranges.INT8RANGE: _Psycopg2NumericRange,
+            ranges.NUMRANGE: _Psycopg2NumericRange,
+            ranges.DATERANGE: _Psycopg2DateRange,
+            ranges.TSRANGE: _Psycopg2DateTimeRange,
+            ranges.TSTZRANGE: _Psycopg2DateTimeTZRange,
         },
     )
 
index 4f010abf1351702d3cee715cd82023fbd853ab06..edbe165d987aa4031d06c9380cfd20ff94443bf4 100644 (file)
@@ -5,28 +5,91 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
 
+from __future__ import annotations
+
+import dataclasses
+from typing import Any
+from typing import Generic
+from typing import Optional
+from typing import TypeVar
 
 from ... import types as sqltypes
+from ...util import py310
+from ...util.typing import Literal
+
+_T = TypeVar("_T", bound=Any)
+
+
+if py310:
+    dc_slots = {"slots": True}
+    dc_kwonly = {"kw_only": True}
+else:
+    dc_slots = {}
+    dc_kwonly = {}
+
 
+@dataclasses.dataclass(frozen=True, **dc_slots)
+class Range(Generic[_T]):
+    """Represent a PostgreSQL range.
 
-__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE")
+    E.g.::
 
+        r = Range(10, 50, bounds="()")
+
+    The calling style is similar to that of psycopg and psycopg2, in part
+    to allow easier migration from previous SQLAlchemy versions that used
+    these objects directly.
+
+    :param lower: Lower bound value, or None
+    :param upper: Upper bound value, or None
+    :param bounds: keyword-only, optional string value that is one of
+     ``"()"``, ``"[)"``, ``"(]"``, ``"[]"``.  Defaults to ``"[)"``.
+    :param empty: keyword-only, optional bool indicating this is an "empty"
+     range
+
+    .. versionadded:: 2.0
 
-class RangeOperators:
     """
-    This mixin provides functionality for the Range Operators
-    listed in the Range Operators table of the `PostgreSQL documentation`__
-    for Range Functions and Operators. It is used by all the range types
-    provided in the ``postgres`` dialect and can likely be used for
-    any range types you create yourself.
 
-    __ https://www.postgresql.org/docs/current/static/functions-range.html
+    lower: Optional[_T] = None
+    """the lower bound"""
+
+    upper: Optional[_T] = None
+    """the upper bound"""
 
-    No extra support is provided for the Range Functions listed in the Range
-    Functions table of the PostgreSQL documentation. For these, the normal
-    :func:`~sqlalchemy.sql.expression.func` object should be used.
+    bounds: Literal["()", "[)", "(]", "[]"] = dataclasses.field(
+        default="[)", **dc_kwonly
+    )
+    empty: bool = dataclasses.field(default=False, **dc_kwonly)
 
+    if not py310:
+
+        def __init__(
+            self, lower=None, upper=None, *, bounds="[)", empty=False
+        ):
+            # no __slots__ either so we can update dict
+            self.__dict__.update(
+                {
+                    "lower": lower,
+                    "upper": upper,
+                    "bounds": bounds,
+                    "empty": empty,
+                }
+            )
+
+    def __bool__(self) -> bool:
+        return self.empty
+
+
+class AbstractRange(sqltypes.TypeEngine):
     """
+    Base for PostgreSQL RANGE types.
+
+    .. seealso::
+
+        `PostgreSQL range functions <https://www.postgresql.org/docs/current/static/functions-range.html>`_
+
+    """  # noqa: E501
 
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for range types."""
@@ -34,9 +97,7 @@ class RangeOperators:
         def __ne__(self, other):
             "Boolean expression. Returns true if two ranges are not equal"
             if other is None:
-                return super(RangeOperators.comparator_factory, self).__ne__(
-                    other
-                )
+                return super().__ne__(other)
             else:
                 return self.expr.op("<>", is_comparison=True)(other)
 
@@ -104,73 +165,77 @@ class RangeOperators:
             return self.expr.op("+")(other)
 
 
-class INT4RANGE(RangeOperators, sqltypes.TypeEngine):
+class AbstractMultiRange(AbstractRange):
+    """base for PostgreSQL MULTIRANGE types"""
+
+
+class INT4RANGE(AbstractRange):
     """Represent the PostgreSQL INT4RANGE type."""
 
     __visit_name__ = "INT4RANGE"
 
 
-class INT8RANGE(RangeOperators, sqltypes.TypeEngine):
+class INT8RANGE(AbstractRange):
     """Represent the PostgreSQL INT8RANGE type."""
 
     __visit_name__ = "INT8RANGE"
 
 
-class NUMRANGE(RangeOperators, sqltypes.TypeEngine):
+class NUMRANGE(AbstractRange):
     """Represent the PostgreSQL NUMRANGE type."""
 
     __visit_name__ = "NUMRANGE"
 
 
-class DATERANGE(RangeOperators, sqltypes.TypeEngine):
+class DATERANGE(AbstractRange):
     """Represent the PostgreSQL DATERANGE type."""
 
     __visit_name__ = "DATERANGE"
 
 
-class TSRANGE(RangeOperators, sqltypes.TypeEngine):
+class TSRANGE(AbstractRange):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSRANGE"
 
 
-class TSTZRANGE(RangeOperators, sqltypes.TypeEngine):
+class TSTZRANGE(AbstractRange):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZRANGE"
 
 
-class INT4MULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class INT4MULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL INT4MULTIRANGE type."""
 
     __visit_name__ = "INT4MULTIRANGE"
 
 
-class INT8MULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class INT8MULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL INT8MULTIRANGE type."""
 
     __visit_name__ = "INT8MULTIRANGE"
 
 
-class NUMMULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class NUMMULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL NUMMULTIRANGE type."""
 
     __visit_name__ = "NUMMULTIRANGE"
 
 
-class DATEMULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class DATEMULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL DATEMULTIRANGE type."""
 
     __visit_name__ = "DATEMULTIRANGE"
 
 
-class TSMULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class TSMULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSMULTIRANGE"
 
 
-class TSTZMULTIRANGE(RangeOperators, sqltypes.TypeEngine):
+class TSTZMULTIRANGE(AbstractMultiRange):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZMULTIRANGE"
index 756b93bb4783960139d2d2efc14fb130401a4c45..c3dfdadc4e96ef34a71867e43184afdfeb15ca16 100644 (file)
@@ -59,6 +59,7 @@ from .compat import has_refcount_gc as has_refcount_gc
 from .compat import inspect_getfullargspec as inspect_getfullargspec
 from .compat import local_dataclass_fields as local_dataclass_fields
 from .compat import osx as osx
+from .compat import py310 as py310
 from .compat import py311 as py311
 from .compat import py38 as py38
 from .compat import py39 as py39
index 9cbb0bca7a73a32d12314911389ff600384fec6d..fdb114d57c3b48ffc991e1e38dd15d8ea272189d 100644 (file)
@@ -1,4 +1,5 @@
 # coding: utf-8
+import dataclasses
 import datetime
 import itertools
 import logging
@@ -35,6 +36,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import psycopg as psycopg_dialect
 from sqlalchemy.dialects.postgresql import psycopg2 as psycopg2_dialect
+from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_BATCH
 from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_PLAIN
 from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_VALUES
@@ -55,6 +57,7 @@ from sqlalchemy.testing.assertions import AssertsCompiledSQL
 from sqlalchemy.testing.assertions import AssertsExecutionResults
 from sqlalchemy.testing.assertions import eq_
 from sqlalchemy.testing.assertions import eq_regex
+from sqlalchemy.testing.assertions import expect_raises
 from sqlalchemy.testing.assertions import ne_
 
 if True:
@@ -66,6 +69,29 @@ if True:
 class DialectTest(fixtures.TestBase):
     """python-side dialect tests."""
 
+    def test_range_constructor(self):
+        """test kwonly argments in the range constructor, as we had
+        to do dataclasses backwards compat operations"""
+
+        r1 = Range(None, 5)
+        eq_(dataclasses.astuple(r1), (None, 5, "[)", False))
+
+        r1 = Range(10, 5, bounds="()")
+        eq_(dataclasses.astuple(r1), (10, 5, "()", False))
+
+        with expect_raises(TypeError):
+            Range(10, 5, "()")  # type: ignore
+
+        with expect_raises(TypeError):
+            Range(None, None, "()", True)  # type: ignore
+
+    def test_range_frozen(self):
+        r1 = Range(None, 5)
+        eq_(dataclasses.astuple(r1), (None, 5, "[)", False))
+
+        with expect_raises(dataclasses.FrozenInstanceError):
+            r1.lower = 8  # type: ignore
+
     def test_version_parsing(self):
         def mock_conn(res):
             return mock.Mock(
index f774300e68e5bf85911701bcb050a2fa77ccad63..b4c19238d324c0a8a27154d381f6111a48e11dd7 100644 (file)
@@ -1,5 +1,4 @@
 # coding: utf-8
-from collections import defaultdict
 import datetime
 import decimal
 from enum import Enum as _PY_Enum
@@ -53,6 +52,7 @@ from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import NamedType
 from sqlalchemy.dialects.postgresql import NUMMULTIRANGE
 from sqlalchemy.dialects.postgresql import NUMRANGE
+from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql import TSMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
@@ -3708,28 +3708,9 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
 
 
 class _RangeTypeRoundTrip(fixtures.TablesTest):
-    __requires__ = "range_types", "any_psycopg_compatibility"
+    __requires__ = ("range_types",)
     __backend__ = True
 
-    def extras(self):
-        # done this way so we don't get ImportErrors with
-        # older psycopg2 versions.
-        if testing.against("postgresql+psycopg2cffi"):
-            from psycopg2cffi import extras
-        elif testing.against("postgresql+psycopg2"):
-            from psycopg2 import extras
-        elif testing.against("postgresql+psycopg"):
-            from psycopg.types.range import Range
-
-            class psycopg_extras:
-                def __getattr__(self, _):
-                    return Range
-
-            extras = psycopg_extras()
-        else:
-            assert False, "Unknown dialect"
-        return extras
-
     @classmethod
     def define_tables(cls, metadata):
         # no reason ranges shouldn't be primary keys,
@@ -3761,13 +3742,25 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         )
         self._assert_data(connection)
 
+    @testing.requires.any_psycopg_compatibility
     def test_insert_text(self, connection):
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
         )
         self._assert_data(connection)
 
-    def test_union_result(self, connection):
+    def test_union_result_obj(self, connection):
+        # insert
+        connection.execute(
+            self.tables.data_table.insert(), {"range": self._data_obj()}
+        )
+        # select
+        range_ = self.tables.data_table.c.range
+        data = connection.execute(select(range_ + range_)).fetchall()
+        eq_(data, [(self._data_obj(),)])
+
+    @testing.requires.any_psycopg_compatibility
+    def test_union_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -3777,7 +3770,18 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         data = connection.execute(select(range_ + range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
 
-    def test_intersection_result(self, connection):
+    def test_intersection_result_obj(self, connection):
+        # insert
+        connection.execute(
+            self.tables.data_table.insert(), {"range": self._data_obj()}
+        )
+        # select
+        range_ = self.tables.data_table.c.range
+        data = connection.execute(select(range_ * range_)).fetchall()
+        eq_(data, [(self._data_obj(),)])
+
+    @testing.requires.any_psycopg_compatibility
+    def test_intersection_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -3787,7 +3791,18 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         data = connection.execute(select(range_ * range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
 
-    def test_difference_result(self, connection):
+    def test_difference_result_obj(self, connection):
+        # insert
+        connection.execute(
+            self.tables.data_table.insert(), {"range": self._data_obj()}
+        )
+        # select
+        range_ = self.tables.data_table.c.range
+        data = connection.execute(select(range_ - range_)).fetchall()
+        eq_(data, [(self._data_obj().__class__(empty=True),)])
+
+    @testing.requires.any_psycopg_compatibility
+    def test_difference_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -3807,7 +3822,7 @@ class _Int4RangeTests:
         return "[1,2)"
 
     def _data_obj(self):
-        return self.extras().NumericRange(1, 2)
+        return Range(1, 2)
 
 
 class _Int8RangeTests:
@@ -3819,9 +3834,7 @@ class _Int8RangeTests:
         return "[9223372036854775806,9223372036854775807)"
 
     def _data_obj(self):
-        return self.extras().NumericRange(
-            9223372036854775806, 9223372036854775807
-        )
+        return Range(9223372036854775806, 9223372036854775807)
 
 
 class _NumRangeTests:
@@ -3833,9 +3846,7 @@ class _NumRangeTests:
         return "[1.0,2.0)"
 
     def _data_obj(self):
-        return self.extras().NumericRange(
-            decimal.Decimal("1.0"), decimal.Decimal("2.0")
-        )
+        return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0"))
 
 
 class _DateRangeTests:
@@ -3847,9 +3858,7 @@ class _DateRangeTests:
         return "[2013-03-23,2013-03-24)"
 
     def _data_obj(self):
-        return self.extras().DateRange(
-            datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
-        )
+        return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24))
 
 
 class _DateTimeRangeTests:
@@ -3861,7 +3870,7 @@ class _DateTimeRangeTests:
         return "[2013-03-23 14:30,2013-03-23 23:30)"
 
     def _data_obj(self):
-        return self.extras().DateTimeRange(
+        return Range(
             datetime.datetime(2013, 3, 23, 14, 30),
             datetime.datetime(2013, 3, 23, 23, 30),
         )
@@ -3888,7 +3897,7 @@ class _DateTimeTZRangeTests:
         return "[%s,%s)" % self.tstzs()
 
     def _data_obj(self):
-        return self.extras().DateTimeTZRange(*self.tstzs())
+        return Range(*self.tstzs())
 
 
 class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
@@ -4104,30 +4113,9 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
 
 
 class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
-    __requires__ = "range_types", "psycopg_only_compatibility"
+    __requires__ = ("multirange_types",)
     __backend__ = True
 
-    def extras(self):
-        # done this way so we don't get ImportErrors with
-        # older psycopg2 versions.
-        if testing.against("postgresql+psycopg"):
-            from psycopg.types.range import Range
-            from psycopg.types.multirange import Multirange
-
-            class psycopg_extras:
-                def __init__(self):
-                    self.data = defaultdict(
-                        lambda: Range, Multirange=Multirange
-                    )
-
-                def __getattr__(self, name):
-                    return self.data[name]
-
-            extras = psycopg_extras()
-        else:
-            assert False, "Unsupported MultiRange Dialect"
-        return extras
-
     @classmethod
     def define_tables(cls, metadata):
         # no reason ranges shouldn't be primary keys,
@@ -4159,13 +4147,15 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         )
         self._assert_data(connection)
 
+    @testing.requires.any_psycopg_compatibility
     def test_insert_text(self, connection):
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
         )
         self._assert_data(connection)
 
-    def test_union_result(self, connection):
+    @testing.requires.any_psycopg_compatibility
+    def test_union_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -4175,7 +4165,8 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         data = connection.execute(select(range_ + range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
 
-    def test_intersection_result(self, connection):
+    @testing.requires.any_psycopg_compatibility
+    def test_intersection_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -4185,7 +4176,8 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         data = connection.execute(select(range_ * range_)).fetchall()
         eq_(data, [(self._data_obj(),)])
 
-    def test_difference_result(self, connection):
+    @testing.requires.any_psycopg_compatibility
+    def test_difference_result_text(self, connection):
         # insert
         connection.execute(
             self.tables.data_table.insert(), {"range": self._data_str()}
@@ -4193,7 +4185,7 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         # select
         range_ = self.tables.data_table.c.range
         data = connection.execute(select(range_ - range_)).fetchall()
-        eq_(data, [(self.extras().Multirange(),)])
+        eq_(data, [([],)])
 
 
 class _Int4MultiRangeTests:
@@ -4205,13 +4197,11 @@ class _Int4MultiRangeTests:
         return "{[1,2), [3, 5), [9, 12)}"
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(1, 2),
-                self.extras().Range(3, 5),
-                self.extras().Range(9, 12),
-            ]
-        )
+        return [
+            Range(1, 2),
+            Range(3, 5),
+            Range(9, 12),
+        ]
 
 
 class _Int8MultiRangeTests:
@@ -4226,12 +4216,10 @@ class _Int8MultiRangeTests:
         )
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(9223372036854775801, 9223372036854775803),
-                self.extras().Range(9223372036854775805, 9223372036854775807),
-            ]
-        )
+        return [
+            Range(9223372036854775801, 9223372036854775803),
+            Range(9223372036854775805, 9223372036854775807),
+        ]
 
 
 class _NumMultiRangeTests:
@@ -4243,19 +4231,11 @@ class _NumMultiRangeTests:
         return "{[1.0,2.0), [3.0, 5.0), [9.0, 12.0)}"
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(
-                    decimal.Decimal("1.0"), decimal.Decimal("2.0")
-                ),
-                self.extras().Range(
-                    decimal.Decimal("3.0"), decimal.Decimal("5.0")
-                ),
-                self.extras().Range(
-                    decimal.Decimal("9.0"), decimal.Decimal("12.0")
-                ),
-            ]
-        )
+        return [
+            Range(decimal.Decimal("1.0"), decimal.Decimal("2.0")),
+            Range(decimal.Decimal("3.0"), decimal.Decimal("5.0")),
+            Range(decimal.Decimal("9.0"), decimal.Decimal("12.0")),
+        ]
 
 
 class _DateMultiRangeTests:
@@ -4267,16 +4247,10 @@ class _DateMultiRangeTests:
         return "{[2013-03-23,2013-03-24), [2014-05-23,2014-05-24)}"
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(
-                    datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)
-                ),
-                self.extras().Range(
-                    datetime.date(2014, 5, 23), datetime.date(2014, 5, 24)
-                ),
-            ]
-        )
+        return [
+            Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)),
+            Range(datetime.date(2014, 5, 23), datetime.date(2014, 5, 24)),
+        ]
 
 
 class _DateTimeMultiRangeTests:
@@ -4291,18 +4265,16 @@ class _DateTimeMultiRangeTests:
         )
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(
-                    datetime.datetime(2013, 3, 23, 14, 30),
-                    datetime.datetime(2013, 3, 23, 23, 30),
-                ),
-                self.extras().Range(
-                    datetime.datetime(2014, 5, 23, 14, 30),
-                    datetime.datetime(2014, 5, 23, 23, 30),
-                ),
-            ]
-        )
+        return [
+            Range(
+                datetime.datetime(2013, 3, 23, 14, 30),
+                datetime.datetime(2013, 3, 23, 23, 30),
+            ),
+            Range(
+                datetime.datetime(2014, 5, 23, 14, 30),
+                datetime.datetime(2014, 5, 23, 23, 30),
+            ),
+        ]
 
 
 class _DateTimeTZMultiRangeTests:
@@ -4344,12 +4316,10 @@ class _DateTimeTZMultiRangeTests:
         )
 
     def _data_obj(self):
-        return self.extras().Multirange(
-            [
-                self.extras().Range(*self.tstzs()),
-                self.extras().Range(*self.tstzs_delta()),
-            ]
-        )
+        return [
+            Range(*self.tstzs()),
+            Range(*self.tstzs_delta()),
+        ]
 
 
 class Int4MultiRangeCompilationTest(
index 8cd586efd41d80da4751f73583aa4eb7a0678632..ef779cf2d6ee9e7c7ebf1c34d05ac95d4da57936 100644 (file)
@@ -1353,17 +1353,11 @@ class DefaultRequirements(SuiteRequirements):
 
     @property
     def range_types(self):
-        def check_range_types(config):
-            if not self.any_psycopg_compatibility.enabled:
-                return False
-            try:
-                with config.db.connect() as conn:
-                    conn.exec_driver_sql("select '[1,2)'::int4range;").scalar()
-                return True
-            except Exception:
-                return False
+        return only_on(["+psycopg2", "+psycopg", "+asyncpg"])
 
-        return only_if(check_range_types)
+    @property
+    def multirange_types(self):
+        return only_on(["+psycopg", "+asyncpg"]) + only_on("postgresql >= 14")
 
     @property
     def async_dialect(self):