From 9ae645d5d1a8cc7732a6d335be6205d0b21e31b1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 20 Sep 2022 12:21:14 -0400 Subject: [PATCH] auto-cast PG range types Range type handling has been enhanced so that it automatically renders type casts, so that in-place round trips for statements that don't provide the database with any context don't require the :func:`_sql.cast` construct to be explicit for the database to know the desired type. Change-Id: Id630b726f8a23059dd2f4cbc410bf5229d89cbfb References: #8540 --- doc/build/changelog/unreleased_20/7156.rst | 8 +- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 4 +- lib/sqlalchemy/dialects/postgresql/psycopg.py | 4 +- .../dialects/postgresql/psycopg2.py | 2 +- lib/sqlalchemy/dialects/postgresql/ranges.py | 39 ++++++++ test/dialect/postgresql/test_types.py | 89 +++++++++++++++++++ 6 files changed, 140 insertions(+), 6 deletions(-) diff --git a/doc/build/changelog/unreleased_20/7156.rst b/doc/build/changelog/unreleased_20/7156.rst index 2d409521f8..cd81c9a6c1 100644 --- a/doc/build/changelog/unreleased_20/7156.rst +++ b/doc/build/changelog/unreleased_20/7156.rst @@ -1,6 +1,6 @@ .. change:: :tags: postgresql, usecase - :tickets: 7156 + :tickets: 7156, 8540 Adds support for PostgreSQL multirange types, introduced in PostgreSQL 14. Support for PostgreSQL ranges and multiranges has now been generalized to @@ -9,6 +9,12 @@ that's constructor-compatible with the previously used psycopg2 object. See the new documentation for usage patterns. + In addition, range type handling has been enhanced so that it automatically + renders type casts, so that in-place round trips for statements that don't + provide the database with any context don't require the :func:`_sql.cast` + construct to be explicit for the database to know the desired type + (discussed at :ticket:`8540`). + Thanks very much to @zeeeeeb for the pull request implementing and testing the new datatypes and psycopg support. diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 4cc04d20ae..c953d34471 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -291,7 +291,7 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True -class _AsyncpgRange(ranges.AbstractRange): +class _AsyncpgRange(ranges.AbstractRangeImpl): def bind_processor(self, dialect): Range = dialect.dbapi.asyncpg.Range @@ -326,7 +326,7 @@ class _AsyncpgRange(ranges.AbstractRange): return to_range -class _AsyncpgMultiRange(ranges.AbstractMultiRange): +class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): def bind_processor(self, dialect): Range = dialect.dbapi.asyncpg.Range diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 371bf2bc23..7ca274e2c7 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -162,7 +162,7 @@ class _PGBoolean(sqltypes.Boolean): render_bind_cast = True -class _PsycopgRange(ranges.AbstractRange): +class _PsycopgRange(ranges.AbstractRangeImpl): def bind_processor(self, dialect): Range = cast(PGDialect_psycopg, dialect)._psycopg_Range @@ -191,7 +191,7 @@ class _PsycopgRange(ranges.AbstractRange): return to_range -class _PsycopgMultiRange(ranges.AbstractMultiRange): +class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): def bind_processor(self, dialect): Range = cast(PGDialect_psycopg, dialect)._psycopg_Range Multirange = cast(PGDialect_psycopg, dialect)._psycopg_Multirange diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 5dcd449cab..a01f20e99f 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -507,7 +507,7 @@ class _PGJSONB(JSONB): return None -class _Psycopg2Range(ranges.AbstractRange): +class _Psycopg2Range(ranges.AbstractRangeImpl): _psycopg2_range_cls = "none" def bind_processor(self, dialect): diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index edbe165d98..327feb4092 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -91,6 +91,35 @@ class AbstractRange(sqltypes.TypeEngine): """ # noqa: E501 + render_bind_cast = True + + def adapt(self, impltype): + """dynamically adapt a range type to an abstract impl. + + For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should + produce a type that will have ``_Psycopg2NumericRange`` behaviors + and also render as ``INT4RANGE`` in SQL and DDL. + + """ + if issubclass(impltype, AbstractRangeImpl): + # two ways to do this are: 1. create a new type on the fly + # or 2. have AbstractRangeImpl(visit_name) constructor and a + # visit_abstract_range_impl() method in the PG compiler. + # I'm choosing #1 as the resulting type object + # will then make use of the same mechanics + # as if we had made all these sub-types explicitly, and will + # also look more obvious under pdb etc. + # The adapt() operation here is cached per type-class-per-dialect, + # so is not much of a performance concern + visit_name = self.__visit_name__ + return type( + f"{visit_name}RangeImpl", + (impltype, self.__class__), + {"__visit_name__": visit_name}, + )() + else: + return super().adapt(impltype) + class comparator_factory(sqltypes.Concatenable.Comparator): """Define comparison operations for range types.""" @@ -165,10 +194,20 @@ class AbstractRange(sqltypes.TypeEngine): return self.expr.op("+")(other) +class AbstractRangeImpl(AbstractRange): + """marker for AbstractRange that will apply a subclass-specific + adaptation""" + + class AbstractMultiRange(AbstractRange): """base for PostgreSQL MULTIRANGE types""" +class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange): + """marker for AbstractRange that will apply a subclass-specific + adaptation""" + + class INT4RANGE(AbstractRange): """Represent the PostgreSQL INT4RANGE type.""" diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index b5c20bd8d4..1f93a40235 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3690,6 +3690,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_where_equal_obj(self): + self._test_clause( + self.col == self._data_obj(), + f"data_table.range = %(range_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_where_not_equal(self): self._test_clause( self.col != self._data_str(), @@ -3697,6 +3704,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_where_not_equal_obj(self): + self._test_clause( + self.col != self._data_obj(), + f"data_table.range <> %(range_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_where_is_null(self): self._test_clause( self.col == None, "data_table.range IS NULL", sqltypes.BOOLEANTYPE @@ -3744,6 +3758,13 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_contains_obj(self): + self._test_clause( + self.col.contains(self._data_obj()), + f"data_table.range @> %(range_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_contained_by(self): self._test_clause( self.col.contained_by(self._data_str()), @@ -3840,6 +3861,26 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): ) cls.col = table.c.range + def test_auto_cast_back_to_type(self, connection): + """test that a straight pass of the range type without any context + will send appropriate casting info so that the driver can round + trip it. + + This doesn't happen in general across other backends and not for + types like JSON etc., although perhaps it should, as we now have + pretty straightforward infrastructure to turn it on; asyncpg + for example does cast JSONs now in place. But that's a + bigger issue; for PG ranges it's likely useful to do this for + PG backends as this is a fairly narrow use case. + + Brought up in #8540. + + """ + data_obj = self._data_obj() + stmt = select(literal(data_obj, type_=self._col_type)) + round_trip = connection.scalar(stmt) + eq_(round_trip, data_obj) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) @@ -4093,6 +4134,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_where_equal_obj(self): + self._test_clause( + self.col == self._data_obj(), + f"data_table.multirange = %(multirange_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_where_not_equal(self): self._test_clause( self.col != self._data_str(), @@ -4100,6 +4148,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_where_not_equal_obj(self): + self._test_clause( + self.col != self._data_obj(), + f"data_table.multirange <> %(multirange_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_where_is_null(self): self._test_clause( self.col == None, @@ -4156,6 +4211,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_contained_by_obj(self): + self._test_clause( + self.col.contained_by(self._data_obj()), + f"data_table.multirange <@ %(multirange_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_overlaps(self): self._test_clause( self.col.overlaps(self._data_str()), @@ -4208,6 +4270,13 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): sqltypes.BOOLEANTYPE, ) + def test_adjacent_to_obj(self): + self._test_clause( + self.col.adjacent_to(self._data_obj()), + f"data_table.multirange -|- %(multirange_1)s::{self._col_str}", + sqltypes.BOOLEANTYPE, + ) + def test_union(self): self._test_clause( self.col + self.col, @@ -4245,6 +4314,26 @@ class _MultiRangeTypeRoundTrip(fixtures.TablesTest): ) cls.col = table.c.range + def test_auto_cast_back_to_type(self, connection): + """test that a straight pass of the range type without any context + will send appropriate casting info so that the driver can round + trip it. + + This doesn't happen in general across other backends and not for + types like JSON etc., although perhaps it should, as we now have + pretty straightforward infrastructure to turn it on; asyncpg + for example does cast JSONs now in place. But that's a + bigger issue; for PG ranges it's likely useful to do this for + PG backends as this is a fairly narrow use case. + + Brought up in #8540. + + """ + data_obj = self._data_obj() + stmt = select(literal(data_obj, type_=self._col_type)) + round_trip = connection.scalar(stmt) + eq_(round_trip, data_obj) + def test_actual_type(self): eq_(str(self._col_type()), self._col_str) -- 2.47.2