From: Jim Bosch Date: Mon, 13 Nov 2023 15:26:54 +0000 (-0500) Subject: Fix typing generics in PostgreSQL range types. X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F10625%2Fhead;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix typing generics in PostgreSQL range types. AbstractRange is parameterized on the bounds type of the range, not the range type; this led to e.g. Range[Range[int]] appearing as the type of some expressions that were actually Range[int]. --- diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index f1c29897d0..15728abada 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -151,7 +151,7 @@ class Range(Generic[_T]): return not self.empty and self.upper is None @property - def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: + def __sa_type_engine__(self) -> AbstractRange[_T]: return AbstractRange() def _contains_value(self, value: _T) -> bool: @@ -856,7 +856,7 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): return self.expr.operate(operators.mul, other) -class AbstractRangeImpl(AbstractRange[Range[_T]]): +class AbstractRangeImpl(AbstractRange[_T]): """Marker for AbstractRange that will apply a subclass-specific adaptation""" @@ -867,80 +867,78 @@ class AbstractMultiRange(AbstractRange[Range[_T]]): __abstract__ = True -class AbstractMultiRangeImpl( - AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]] -): +class AbstractMultiRangeImpl(AbstractRangeImpl[_T], AbstractMultiRange[_T]): """Marker for AbstractRange that will apply a subclass-specific adaptation""" -class INT4RANGE(AbstractRange[Range[int]]): +class INT4RANGE(AbstractRange[int]): """Represent the PostgreSQL INT4RANGE type.""" __visit_name__ = "INT4RANGE" -class INT8RANGE(AbstractRange[Range[int]]): +class INT8RANGE(AbstractRange[int]): """Represent the PostgreSQL INT8RANGE type.""" __visit_name__ = "INT8RANGE" -class NUMRANGE(AbstractRange[Range[Decimal]]): +class NUMRANGE(AbstractRange[Decimal]): """Represent the PostgreSQL NUMRANGE type.""" __visit_name__ = "NUMRANGE" -class DATERANGE(AbstractRange[Range[date]]): +class DATERANGE(AbstractRange[date]): """Represent the PostgreSQL DATERANGE type.""" __visit_name__ = "DATERANGE" -class TSRANGE(AbstractRange[Range[datetime]]): +class TSRANGE(AbstractRange[datetime]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSRANGE" -class TSTZRANGE(AbstractRange[Range[datetime]]): +class TSTZRANGE(AbstractRange[datetime]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZRANGE" -class INT4MULTIRANGE(AbstractMultiRange[Range[int]]): +class INT4MULTIRANGE(AbstractMultiRange[int]): """Represent the PostgreSQL INT4MULTIRANGE type.""" __visit_name__ = "INT4MULTIRANGE" -class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): +class INT8MULTIRANGE(AbstractMultiRange[int]): """Represent the PostgreSQL INT8MULTIRANGE type.""" __visit_name__ = "INT8MULTIRANGE" -class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): +class NUMMULTIRANGE(AbstractMultiRange[Decimal]): """Represent the PostgreSQL NUMMULTIRANGE type.""" __visit_name__ = "NUMMULTIRANGE" -class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): +class DATEMULTIRANGE(AbstractMultiRange[date]): """Represent the PostgreSQL DATEMULTIRANGE type.""" __visit_name__ = "DATEMULTIRANGE" -class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSMULTIRANGE" -class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): +class TSTZMULTIRANGE(AbstractMultiRange[datetime]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZMULTIRANGE" diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index 4567daa386..0a061fd736 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -12,14 +12,16 @@ from sqlalchemy import Text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.dialects.postgresql import INT4RANGE +from sqlalchemy.dialects.postgresql import INT8RANGE from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column - # test #6402 c1 = Column(UUID()) @@ -77,3 +79,11 @@ insert(Test).on_conflict_do_nothing( ).on_conflict_do_update( unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 ).excluded.foo.desc() + + +# EXPECTED_TYPE: Column[Range[int]] +reveal_type(Column(INT4RANGE())) +# EXPECTED_TYPE: Column[Range[int]] +reveal_type(Column(INT8RANGE())) +# EXPECTED_TYPE: Column[Range[datetime.date]] +reveal_type(Column(DATERANGE()))