]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add value-level hooks for SQL type detection; apply to Range
authorLele Gaifax <lele@metapensiero.it>
Sun, 27 Nov 2022 16:28:51 +0000 (11:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 29 Nov 2022 22:11:38 +0000 (17:11 -0500)
Added additional type-detection for the new PostgreSQL
:class:`_postgresql.Range` type, where previous cases that allowed the
psycopg2-native range objects to be received directly by the DBAPI without
SQLAlchemy intercepting them stopped working, as we now have our own value
object. The :class:`_postgresql.Range` object has been enhanced such that
SQLAlchemy Core detects it in otherwise ambiguous situations (such as
comparison to dates) and applies appropriate bind handlers. Pull request
courtesy Lele Gaifax.

Fixes: #8884
Closes: #8886
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8886
Pull-request-sha: 6e95e08a30597d3735ab38f2f1a2ccabd968852c

Change-Id: I3ca277c826dcf4b5644f44eb251345b439a84ee4

doc/build/changelog/unreleased_20/8884.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_compiler.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/8884.rst b/doc/build/changelog/unreleased_20/8884.rst
new file mode 100644 (file)
index 0000000..0edbb95
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 8884
+
+    Added additional type-detection for the new PostgreSQL
+    :class:`_postgresql.Range` type, where previous cases that allowed the
+    psycopg2-native range objects to be received directly by the DBAPI without
+    SQLAlchemy intercepting them stopped working, as we now have our own value
+    object. The :class:`_postgresql.Range` object has been enhanced such that
+    SQLAlchemy Core detects it in otherwise ambiguous situations (such as
+    comparison to dates) and applies appropriate bind handlers. Pull request
+    courtesy Lele Gaifax.
index a4c39d0639f6282acf3d120678b5ab5d1c87d345..6f13d462ad28349f997f9f457a01b2a5bb33412c 100644 (file)
@@ -11,6 +11,7 @@ import dataclasses
 from datetime import date
 from datetime import datetime
 from datetime import timedelta
+from decimal import Decimal
 from typing import Any
 from typing import Generic
 from typing import Optional
@@ -84,6 +85,10 @@ class Range(Generic[_T]):
     def __bool__(self) -> bool:
         return self.empty
 
+    @property
+    def __sa_type_engine__(self):
+        return AbstractRange()
+
     def _contains_value(self, value: _T) -> bool:
         "Check whether this range contains the given `value`."
 
@@ -622,6 +627,21 @@ class AbstractRange(sqltypes.TypeEngine):
         else:
             return super().adapt(impltype)
 
+    def _resolve_for_literal(self, value):
+        spec = value.lower if value.lower is not None else value.upper
+
+        if isinstance(spec, int):
+            return INT8RANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMRANGE()
+        elif isinstance(spec, datetime):
+            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
+        elif isinstance(spec, date):
+            return DATERANGE()
+        else:
+            # empty Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for range types."""
 
index 624b7d16efc735e11e48408c3f17e3c6c8386c9f..308c233e4a17fd88decdac600e6a1434397ce30b 100644 (file)
@@ -3682,6 +3682,10 @@ _type_map_get = _type_map.get
 
 def _resolve_value_to_type(value: Any) -> TypeEngine[Any]:
     _result_type = _type_map_get(type(value), False)
+
+    if _result_type is False:
+        _result_type = getattr(value, "__sa_type_engine__", False)
+
     if _result_type is False:
         # use inspect() to detect SQLAlchemy built-in
         # objects.
index 431cd7ded156697830b9374467a6187a09e0a5de..ee3372c74917a774b6f818dafda23da2ae17d41b 100644 (file)
@@ -42,6 +42,7 @@ from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import JSONPATH
+from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
@@ -2397,6 +2398,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS jsonb_path_exists_1 FROM data",
         )
 
+    def test_custom_object_hook(self):
+        # See issue #8884
+        from datetime import date
+
+        usages = table(
+            "usages",
+            column("id", Integer),
+            column("date", Date),
+            column("amount", Integer),
+        )
+        period = Range(date(2022, 1, 1), (2023, 1, 1))
+        stmt = select(func.sum(usages.c.amount)).where(
+            usages.c.date.op("<@")(period)
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT sum(usages.amount) AS sum_1 FROM usages "
+            "WHERE usages.date <@ %(date_1)s::DATERANGE",
+        )
+
 
 class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()
index d1b32186e90c105ce613dbc371cfffcca8345e47..91413ff3597efdcc84c46b93a1854219f50be2ab 100644 (file)
@@ -3293,6 +3293,43 @@ class ExpressionTest(
             ],
         )
 
+    @testing.variation("secondary_adapt", [True, False])
+    @testing.variation("expression_type", ["literal", "right_side"])
+    def test_value_level_bind_hooks(
+        self, connection, metadata, secondary_adapt, expression_type
+    ):
+        """test new feature added in #8884, allowing custom value objects
+        to indicate the SQL type they should resolve towards.
+
+        """
+
+        class MyFoobarType(types.UserDefinedType):
+            if secondary_adapt:
+
+                def _resolve_for_literal(self, value):
+                    return String(value.length)
+
+        class Widget:
+            def __init__(self, length):
+                self.length = length
+
+            @property
+            def __sa_type_engine__(self):
+                return MyFoobarType()
+
+        if expression_type.literal:
+            expr = literal(Widget(52))
+        elif expression_type.right_side:
+            expr = (column("x", Integer) == Widget(52)).right
+        else:
+            assert False
+
+        if secondary_adapt:
+            is_(expr.type._type_affinity, String)
+            eq_(expr.type.length, 52)
+        else:
+            is_(expr.type._type_affinity, MyFoobarType)
+
     def test_grouped_bind_adapt(self):
         test_table = self.tables.test