]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
auto-cast PG range types
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2022 16:21:14 +0000 (12:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Sep 2022 16:34:46 +0000 (12:34 -0400)
Range type handling has been enhanced so that it automatically
renders type casts, so that in-place round trips for statements that don't
provide the database with any context don't require the :func:`_sql.cast`
construct to be explicit for the database to know the desired type.

Change-Id: Id630b726f8a23059dd2f4cbc410bf5229d89cbfb
References: #8540

doc/build/changelog/unreleased_20/7156.rst
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 2d409521f8949fb10e37df10931a583d90a1ede2..cd81c9a6c1b3db02e88602313268e50fab783970 100644 (file)
@@ -1,6 +1,6 @@
 .. change::
     :tags: postgresql, usecase
-    :tickets: 7156
+    :tickets: 7156, 8540
 
     Adds support for PostgreSQL multirange types, introduced in PostgreSQL 14.
     Support for PostgreSQL ranges and multiranges has now been generalized to
@@ -9,6 +9,12 @@
     that's constructor-compatible with the previously used psycopg2 object. See
     the new documentation for usage patterns.
 
+    In addition, range type handling has been enhanced so that it automatically
+    renders type casts, so that in-place round trips for statements that don't
+    provide the database with any context don't require the :func:`_sql.cast`
+    construct to be explicit for the database to know the desired type
+    (discussed at :ticket:`8540`).
+
     Thanks very much to @zeeeeeb for the pull request implementing and testing
     the new datatypes and psycopg support.
 
index 4cc04d20aee5be5f6ab5d03a6aadffd1ae33d885..c953d34471a8a85c873c177a2fc768d1f66719ad 100644 (file)
@@ -291,7 +291,7 @@ class AsyncpgCHAR(sqltypes.CHAR):
     render_bind_cast = True
 
 
-class _AsyncpgRange(ranges.AbstractRange):
+class _AsyncpgRange(ranges.AbstractRangeImpl):
     def bind_processor(self, dialect):
         Range = dialect.dbapi.asyncpg.Range
 
@@ -326,7 +326,7 @@ class _AsyncpgRange(ranges.AbstractRange):
         return to_range
 
 
-class _AsyncpgMultiRange(ranges.AbstractMultiRange):
+class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
     def bind_processor(self, dialect):
         Range = dialect.dbapi.asyncpg.Range
 
index 371bf2bc232d75f5509a306316091d6e91f0edfa..7ca274e2c779bb4def5022b356fbe54ad1937910 100644 (file)
@@ -162,7 +162,7 @@ class _PGBoolean(sqltypes.Boolean):
     render_bind_cast = True
 
 
-class _PsycopgRange(ranges.AbstractRange):
+class _PsycopgRange(ranges.AbstractRangeImpl):
     def bind_processor(self, dialect):
         Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
 
@@ -191,7 +191,7 @@ class _PsycopgRange(ranges.AbstractRange):
         return to_range
 
 
-class _PsycopgMultiRange(ranges.AbstractMultiRange):
+class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
     def bind_processor(self, dialect):
         Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
         Multirange = cast(PGDialect_psycopg, dialect)._psycopg_Multirange
index 5dcd449cab8cbcbe61529f8b68cd7b4c0527f028..a01f20e99fbb71ae2bb0d68a81cef3d1b8f9f485 100644 (file)
@@ -507,7 +507,7 @@ class _PGJSONB(JSONB):
         return None
 
 
-class _Psycopg2Range(ranges.AbstractRange):
+class _Psycopg2Range(ranges.AbstractRangeImpl):
     _psycopg2_range_cls = "none"
 
     def bind_processor(self, dialect):
index edbe165d987aa4031d06c9380cfd20ff94443bf4..327feb4092b0e700c073e5d132957d7e7cb03582 100644 (file)
@@ -91,6 +91,35 @@ class AbstractRange(sqltypes.TypeEngine):
 
     """  # noqa: E501
 
+    render_bind_cast = True
+
+    def adapt(self, impltype):
+        """dynamically adapt a range type to an abstract impl.
+
+        For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should
+        produce a type that will have ``_Psycopg2NumericRange`` behaviors
+        and also render as ``INT4RANGE`` in SQL and DDL.
+
+        """
+        if issubclass(impltype, AbstractRangeImpl):
+            # two ways to do this are:  1. create a new type on the fly
+            # or 2. have AbstractRangeImpl(visit_name) constructor and a
+            # visit_abstract_range_impl() method in the PG compiler.
+            # I'm choosing #1 as the resulting type object
+            # will then make use of the same mechanics
+            # as if we had made all these sub-types explicitly, and will
+            # also look more obvious under pdb etc.
+            # The adapt() operation here is cached per type-class-per-dialect,
+            # so is not much of a performance concern
+            visit_name = self.__visit_name__
+            return type(
+                f"{visit_name}RangeImpl",
+                (impltype, self.__class__),
+                {"__visit_name__": visit_name},
+            )()
+        else:
+            return super().adapt(impltype)
+
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for range types."""
 
@@ -165,10 +194,20 @@ class AbstractRange(sqltypes.TypeEngine):
             return self.expr.op("+")(other)
 
 
+class AbstractRangeImpl(AbstractRange):
+    """marker for AbstractRange that will apply a subclass-specific
+    adaptation"""
+
+
 class AbstractMultiRange(AbstractRange):
     """base for PostgreSQL MULTIRANGE types"""
 
 
+class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
+    """marker for AbstractRange that will apply a subclass-specific
+    adaptation"""
+
+
 class INT4RANGE(AbstractRange):
     """Represent the PostgreSQL INT4RANGE type."""
 
index b5c20bd8d42c8cee53059111d7147df6bbd0ea9c..1f93a40235f70cfdd9c7877f73719a108b184f37 100644 (file)
@@ -3690,6 +3690,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_where_equal_obj(self):
+        self._test_clause(
+            self.col == self._data_obj(),
+            f"data_table.range = %(range_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_where_not_equal(self):
         self._test_clause(
             self.col != self._data_str(),
@@ -3697,6 +3704,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_where_not_equal_obj(self):
+        self._test_clause(
+            self.col != self._data_obj(),
+            f"data_table.range <> %(range_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_where_is_null(self):
         self._test_clause(
             self.col == None, "data_table.range IS NULL", sqltypes.BOOLEANTYPE
@@ -3744,6 +3758,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_contains_obj(self):
+        self._test_clause(
+            self.col.contains(self._data_obj()),
+            f"data_table.range @> %(range_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_contained_by(self):
         self._test_clause(
             self.col.contained_by(self._data_str()),
@@ -3840,6 +3861,26 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         )
         cls.col = table.c.range
 
+    def test_auto_cast_back_to_type(self, connection):
+        """test that a straight pass of the range type without any context
+        will send appropriate casting info so that the driver can round
+        trip it.
+
+        This doesn't happen in general across other backends and not for
+        types like JSON etc., although perhaps it should, as we now have
+        pretty straightforward infrastructure to turn it on; asyncpg
+        for example does cast JSONs now in place.  But that's a
+        bigger issue; for PG ranges it's likely useful to do this for
+        PG backends as this is a fairly narrow use case.
+
+        Brought up in #8540.
+
+        """
+        data_obj = self._data_obj()
+        stmt = select(literal(data_obj, type_=self._col_type))
+        round_trip = connection.scalar(stmt)
+        eq_(round_trip, data_obj)
+
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)
 
@@ -4093,6 +4134,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_where_equal_obj(self):
+        self._test_clause(
+            self.col == self._data_obj(),
+            f"data_table.multirange = %(multirange_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_where_not_equal(self):
         self._test_clause(
             self.col != self._data_str(),
@@ -4100,6 +4148,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_where_not_equal_obj(self):
+        self._test_clause(
+            self.col != self._data_obj(),
+            f"data_table.multirange <> %(multirange_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_where_is_null(self):
         self._test_clause(
             self.col == None,
@@ -4156,6 +4211,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_contained_by_obj(self):
+        self._test_clause(
+            self.col.contained_by(self._data_obj()),
+            f"data_table.multirange <@ %(multirange_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_overlaps(self):
         self._test_clause(
             self.col.overlaps(self._data_str()),
@@ -4208,6 +4270,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             sqltypes.BOOLEANTYPE,
         )
 
+    def test_adjacent_to_obj(self):
+        self._test_clause(
+            self.col.adjacent_to(self._data_obj()),
+            f"data_table.multirange -|- %(multirange_1)s::{self._col_str}",
+            sqltypes.BOOLEANTYPE,
+        )
+
     def test_union(self):
         self._test_clause(
             self.col + self.col,
@@ -4245,6 +4314,26 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest):
         )
         cls.col = table.c.range
 
+    def test_auto_cast_back_to_type(self, connection):
+        """test that a straight pass of the range type without any context
+        will send appropriate casting info so that the driver can round
+        trip it.
+
+        This doesn't happen in general across other backends and not for
+        types like JSON etc., although perhaps it should, as we now have
+        pretty straightforward infrastructure to turn it on; asyncpg
+        for example does cast JSONs now in place.  But that's a
+        bigger issue; for PG ranges it's likely useful to do this for
+        PG backends as this is a fairly narrow use case.
+
+        Brought up in #8540.
+
+        """
+        data_obj = self._data_obj()
+        stmt = select(literal(data_obj, type_=self._col_type))
+        round_trip = connection.scalar(stmt)
+        eq_(round_trip, data_obj)
+
     def test_actual_type(self):
         eq_(str(self._col_type()), self._col_str)