]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Create a compiler hook to pick up custom objects in expressions 8886/head
authorLele Gaifax <lele@metapensiero.it>
Sun, 27 Nov 2022 16:06:49 +0000 (17:06 +0100)
committerLele Gaifax <lele@metapensiero.it>
Sun, 27 Nov 2022 16:06:49 +0000 (17:06 +0100)
This simply applies what has been suggested by Mike in issue #8884, and
seems to work. I guess more work is needed, for example mentioning
and/or documenting the new __sa_type_engine__ hook.

lib/sqlalchemy/dialects/postgresql/ranges.py
lib/sqlalchemy/sql/sqltypes.py
test/dialect/postgresql/test_compiler.py

index a4c39d0639f6282acf3d120678b5ab5d1c87d345..6f13d462ad28349f997f9f457a01b2a5bb33412c 100644 (file)
@@ -11,6 +11,7 @@ import dataclasses
 from datetime import date
 from datetime import datetime
 from datetime import timedelta
+from decimal import Decimal
 from typing import Any
 from typing import Generic
 from typing import Optional
@@ -84,6 +85,10 @@ class Range(Generic[_T]):
     def __bool__(self) -> bool:
         return self.empty
 
+    @property
+    def __sa_type_engine__(self):
+        return AbstractRange()
+
     def _contains_value(self, value: _T) -> bool:
         "Check whether this range contains the given `value`."
 
@@ -622,6 +627,21 @@ class AbstractRange(sqltypes.TypeEngine):
         else:
             return super().adapt(impltype)
 
+    def _resolve_for_literal(self, value):
+        spec = value.lower if value.lower is not None else value.upper
+
+        if isinstance(spec, int):
+            return INT8RANGE()
+        elif isinstance(spec, (Decimal, float)):
+            return NUMRANGE()
+        elif isinstance(spec, datetime):
+            return TSRANGE() if not spec.tzinfo else TSTZRANGE()
+        elif isinstance(spec, date):
+            return DATERANGE()
+        else:
+            # empty Range, SQL datatype can't be determined here
+            return sqltypes.NULLTYPE
+
     class comparator_factory(sqltypes.Concatenable.Comparator):
         """Define comparison operations for range types."""
 
index 624b7d16efc735e11e48408c3f17e3c6c8386c9f..308c233e4a17fd88decdac600e6a1434397ce30b 100644 (file)
@@ -3682,6 +3682,10 @@ _type_map_get = _type_map.get
 
 def _resolve_value_to_type(value: Any) -> TypeEngine[Any]:
     _result_type = _type_map_get(type(value), False)
+
+    if _result_type is False:
+        _result_type = getattr(value, "__sa_type_engine__", False)
+
     if _result_type is False:
         # use inspect() to detect SQLAlchemy built-in
         # objects.
index 431cd7ded156697830b9374467a6187a09e0a5de..ee3372c74917a774b6f818dafda23da2ae17d41b 100644 (file)
@@ -42,6 +42,7 @@ from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import JSONPATH
+from sqlalchemy.dialects.postgresql import Range
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
@@ -2397,6 +2398,26 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "AS jsonb_path_exists_1 FROM data",
         )
 
+    def test_custom_object_hook(self):
+        # See issue #8884
+        from datetime import date
+
+        usages = table(
+            "usages",
+            column("id", Integer),
+            column("date", Date),
+            column("amount", Integer),
+        )
+        period = Range(date(2022, 1, 1), (2023, 1, 1))
+        stmt = select(func.sum(usages.c.amount)).where(
+            usages.c.date.op("<@")(period)
+        )
+        self.assert_compile(
+            stmt,
+            "SELECT sum(usages.amount) AS sum_1 FROM usages "
+            "WHERE usages.date <@ %(date_1)s::DATERANGE",
+        )
+
 
 class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
     __dialect__ = postgresql.dialect()