Added additional type-detection for the new PostgreSQL
:class:`_postgresql.Range` type, where previous cases that allowed the
psycopg2-native range objects to be received directly by the DBAPI without
SQLAlchemy intercepting them stopped working, as we now have our own value
object. The :class:`_postgresql.Range` object has been enhanced such that
SQLAlchemy Core detects it in otherwise ambiguous situations (such as
comparison to dates) and applies appropriate bind handlers. Pull request
courtesy Lele Gaifax.
Fixes: #8884
Closes: #8886
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8886
Pull-request-sha:
6e95e08a30597d3735ab38f2f1a2ccabd968852c
Change-Id: I3ca277c826dcf4b5644f44eb251345b439a84ee4
--- /dev/null
+.. change::
+ :tags: bug, postgresql
+ :tickets: 8884
+
+ Added additional type-detection for the new PostgreSQL
+ :class:`_postgresql.Range` type, where previous cases that allowed the
+ psycopg2-native range objects to be received directly by the DBAPI without
+ SQLAlchemy intercepting them stopped working, as we now have our own value
+ object. The :class:`_postgresql.Range` object has been enhanced such that
+ SQLAlchemy Core detects it in otherwise ambiguous situations (such as
+ comparison to dates) and applies appropriate bind handlers. Pull request
+ courtesy Lele Gaifax.
from datetime import date
from datetime import datetime
from datetime import timedelta
+from decimal import Decimal
from typing import Any
from typing import Generic
from typing import Optional
def __bool__(self) -> bool:
return self.empty
+ @property
+ def __sa_type_engine__(self):
+ return AbstractRange()
+
def _contains_value(self, value: _T) -> bool:
"Check whether this range contains the given `value`."
else:
return super().adapt(impltype)
+ def _resolve_for_literal(self, value):
+ spec = value.lower if value.lower is not None else value.upper
+
+ if isinstance(spec, int):
+ return INT8RANGE()
+ elif isinstance(spec, (Decimal, float)):
+ return NUMRANGE()
+ elif isinstance(spec, datetime):
+ return TSRANGE() if not spec.tzinfo else TSTZRANGE()
+ elif isinstance(spec, date):
+ return DATERANGE()
+ else:
+ # empty Range, SQL datatype can't be determined here
+ return sqltypes.NULLTYPE
+
class comparator_factory(sqltypes.Concatenable.Comparator):
"""Define comparison operations for range types."""
def _resolve_value_to_type(value: Any) -> TypeEngine[Any]:
_result_type = _type_map_get(type(value), False)
+
+ if _result_type is False:
+ _result_type = getattr(value, "__sa_type_engine__", False)
+
if _result_type is False:
# use inspect() to detect SQLAlchemy built-in
# objects.
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import JSONPATH
+from sqlalchemy.dialects.postgresql import Range
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
"AS jsonb_path_exists_1 FROM data",
)
+ def test_custom_object_hook(self):
+ # See issue #8884
+ from datetime import date
+
+ usages = table(
+ "usages",
+ column("id", Integer),
+ column("date", Date),
+ column("amount", Integer),
+ )
+ period = Range(date(2022, 1, 1), (2023, 1, 1))
+ stmt = select(func.sum(usages.c.amount)).where(
+ usages.c.date.op("<@")(period)
+ )
+ self.assert_compile(
+ stmt,
+ "SELECT sum(usages.amount) AS sum_1 FROM usages "
+ "WHERE usages.date <@ %(date_1)s::DATERANGE",
+ )
+
class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = postgresql.dialect()
],
)
+ @testing.variation("secondary_adapt", [True, False])
+ @testing.variation("expression_type", ["literal", "right_side"])
+ def test_value_level_bind_hooks(
+ self, connection, metadata, secondary_adapt, expression_type
+ ):
+ """test new feature added in #8884, allowing custom value objects
+ to indicate the SQL type they should resolve towards.
+
+ """
+
+ class MyFoobarType(types.UserDefinedType):
+ if secondary_adapt:
+
+ def _resolve_for_literal(self, value):
+ return String(value.length)
+
+ class Widget:
+ def __init__(self, length):
+ self.length = length
+
+ @property
+ def __sa_type_engine__(self):
+ return MyFoobarType()
+
+ if expression_type.literal:
+ expr = literal(Widget(52))
+ elif expression_type.right_side:
+ expr = (column("x", Integer) == Widget(52)).right
+ else:
+ assert False
+
+ if secondary_adapt:
+ is_(expr.type._type_affinity, String)
+ eq_(expr.type.length, 52)
+ else:
+ is_(expr.type._type_affinity, MyFoobarType)
+
def test_grouped_bind_adapt(self):
test_table = self.tables.test