]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix typing generics in PostgreSQL range types. 10625/head
authorJim Bosch <jbosch@astro.princeton.edu>
Mon, 13 Nov 2023 15:26:54 +0000 (10:26 -0500)
committerJim Bosch <jbosch@astro.princeton.edu>
Tue, 14 Nov 2023 18:26:43 +0000 (13:26 -0500)
AbstractRange is parameterized on the bounds type of the range, not
the range type; this led to e.g. Range[Range[int]] appearing as the
type of some expressions that were actually Range[int].

lib/sqlalchemy/dialects/postgresql/ranges.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index f1c29897d013a371431868dd187266900d8e98aa..15728abada9c9bcc25e44d55b7d46d461ac033c0 100644 (file)
@@ -151,7 +151,7 @@ class Range(Generic[_T]):
         return not self.empty and self.upper is None
 
     @property
-    def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
+    def __sa_type_engine__(self) -> AbstractRange[_T]:
         return AbstractRange()
 
     def _contains_value(self, value: _T) -> bool:
@@ -856,7 +856,7 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
             return self.expr.operate(operators.mul, other)
 
 
-class AbstractRangeImpl(AbstractRange[Range[_T]]):
+class AbstractRangeImpl(AbstractRange[_T]):
     """Marker for AbstractRange that will apply a subclass-specific
     adaptation"""
 
@@ -867,80 +867,78 @@ class AbstractMultiRange(AbstractRange[Range[_T]]):
     __abstract__ = True
 
 
-class AbstractMultiRangeImpl(
-    AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
-):
+class AbstractMultiRangeImpl(AbstractRangeImpl[_T], AbstractMultiRange[_T]):
     """Marker for AbstractRange that will apply a subclass-specific
     adaptation"""
 
 
-class INT4RANGE(AbstractRange[Range[int]]):
+class INT4RANGE(AbstractRange[int]):
     """Represent the PostgreSQL INT4RANGE type."""
 
     __visit_name__ = "INT4RANGE"
 
 
-class INT8RANGE(AbstractRange[Range[int]]):
+class INT8RANGE(AbstractRange[int]):
     """Represent the PostgreSQL INT8RANGE type."""
 
     __visit_name__ = "INT8RANGE"
 
 
-class NUMRANGE(AbstractRange[Range[Decimal]]):
+class NUMRANGE(AbstractRange[Decimal]):
     """Represent the PostgreSQL NUMRANGE type."""
 
     __visit_name__ = "NUMRANGE"
 
 
-class DATERANGE(AbstractRange[Range[date]]):
+class DATERANGE(AbstractRange[date]):
     """Represent the PostgreSQL DATERANGE type."""
 
     __visit_name__ = "DATERANGE"
 
 
-class TSRANGE(AbstractRange[Range[datetime]]):
+class TSRANGE(AbstractRange[datetime]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSRANGE"
 
 
-class TSTZRANGE(AbstractRange[Range[datetime]]):
+class TSTZRANGE(AbstractRange[datetime]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZRANGE"
 
 
-class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
+class INT4MULTIRANGE(AbstractMultiRange[int]):
     """Represent the PostgreSQL INT4MULTIRANGE type."""
 
     __visit_name__ = "INT4MULTIRANGE"
 
 
-class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
+class INT8MULTIRANGE(AbstractMultiRange[int]):
     """Represent the PostgreSQL INT8MULTIRANGE type."""
 
     __visit_name__ = "INT8MULTIRANGE"
 
 
-class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
+class NUMMULTIRANGE(AbstractMultiRange[Decimal]):
     """Represent the PostgreSQL NUMMULTIRANGE type."""
 
     __visit_name__ = "NUMMULTIRANGE"
 
 
-class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
+class DATEMULTIRANGE(AbstractMultiRange[date]):
     """Represent the PostgreSQL DATEMULTIRANGE type."""
 
     __visit_name__ = "DATEMULTIRANGE"
 
 
-class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
+class TSMULTIRANGE(AbstractMultiRange[datetime]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSMULTIRANGE"
 
 
-class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
+class TSTZMULTIRANGE(AbstractMultiRange[datetime]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZMULTIRANGE"
index 4567daa38665ddafe36597d0cd283ac5ae067875..0a061fd736c581944776dbc8f553cb794602d10a 100644 (file)
@@ -12,14 +12,16 @@ from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
+from sqlalchemy.dialects.postgresql import DATERANGE
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.dialects.postgresql import INT4RANGE
+from sqlalchemy.dialects.postgresql import INT8RANGE
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 
-
 # test #6402
 
 c1 = Column(UUID())
@@ -77,3 +79,11 @@ insert(Test).on_conflict_do_nothing(
 ).on_conflict_do_update(
     unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
 ).excluded.foo.desc()
+
+
+# EXPECTED_TYPE: Column[Range[int]]
+reveal_type(Column(INT4RANGE()))
+# EXPECTED_TYPE: Column[Range[int]]
+reveal_type(Column(INT8RANGE()))
+# EXPECTED_TYPE: Column[Range[datetime.date]]
+reveal_type(Column(DATERANGE()))