]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Make ARRAY generic on the item_type
authorDenis Laxalde <denis@laxalde.org>
Tue, 18 Mar 2025 16:23:01 +0000 (12:23 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 18 Mar 2025 21:16:01 +0000 (22:16 +0100)
Now `Column(type_=ARRAY(Integer)` is inferred as `Column[Sequence[int]]` instead as `Column[Sequence[Any]]` previously. This only works with the `type_` argument to Column, but that's not new.

This follows from a suggestion at
https://github.com/sqlalchemy/sqlalchemy/pull/12386#issuecomment-2694056069.

Related to #6810.

Closes: #12443
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12443
Pull-request-sha: 2fff4e89cd0b72d9444ce3f3d845b152770fc55d

Change-Id: I87b828fd82d10fbf157141db3c31f0ec8149caad
(cherry picked from commit 500adfafcb782c5b22ff49e00192a2ed42ed09b6)

lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/sql/sqltypes.py
test/typing/plain_files/dialects/postgresql/pg_stuff.py

index 8cbe0c48cf9a617e5252fe1c29c73545d9087f30..54590ad766046ad604d7c25d56ce4c8de3a84a49 100644 (file)
@@ -199,7 +199,7 @@ class array(expression.ExpressionClauseList[_T]):
             return self
 
 
-class ARRAY(sqltypes.ARRAY):
+class ARRAY(sqltypes.ARRAY[_T]):
     """PostgreSQL ARRAY type.
 
     The :class:`_postgresql.ARRAY` type is constructed in the same way
@@ -273,7 +273,7 @@ class ARRAY(sqltypes.ARRAY):
 
     def __init__(
         self,
-        item_type: _TypeEngineArgument[typing_Any],
+        item_type: _TypeEngineArgument[_T],
         as_tuple: bool = False,
         dimensions: Optional[int] = None,
         zero_indexes: bool = False,
@@ -322,7 +322,7 @@ class ARRAY(sqltypes.ARRAY):
         self.dimensions = dimensions
         self.zero_indexes = zero_indexes
 
-    class Comparator(sqltypes.ARRAY.Comparator):
+    class Comparator(sqltypes.ARRAY.Comparator[_T]):
         """Define comparison operations for :class:`_types.ARRAY`.
 
         Note that these operations are in addition to those provided
@@ -363,7 +363,7 @@ class ARRAY(sqltypes.ARRAY):
     def _against_native_enum(self) -> bool:
         return (
             isinstance(self.item_type, sqltypes.Enum)
-            and self.item_type.native_enum
+            and self.item_type.native_enum  # type: ignore[attr-defined]
         )
 
     def literal_processor(
index 1c316eecf62c2ed24e0d7194d9eb8543613a0972..d0d89e731687a90d6fb3fb66151fbcb86fbe3153 100644 (file)
@@ -2801,7 +2801,7 @@ class JSON(Indexable, TypeEngine[Any]):
 
 
 class ARRAY(
-    SchemaEventTarget, Indexable, Concatenable, TypeEngine[Sequence[Any]]
+    SchemaEventTarget, Indexable, Concatenable, TypeEngine[Sequence[_T]]
 ):
     """Represent a SQL Array type.
 
@@ -2924,7 +2924,7 @@ class ARRAY(
 
     def __init__(
         self,
-        item_type: _TypeEngineArgument[Any],
+        item_type: _TypeEngineArgument[_T],
         as_tuple: bool = False,
         dimensions: Optional[int] = None,
         zero_indexes: bool = False,
@@ -2973,8 +2973,8 @@ class ARRAY(
         self.zero_indexes = zero_indexes
 
     class Comparator(
-        Indexable.Comparator[Sequence[Any]],
-        Concatenable.Comparator[Sequence[Any]],
+        Indexable.Comparator[Sequence[_T]],
+        Concatenable.Comparator[Sequence[_T]],
     ):
         """Define comparison operations for :class:`_types.ARRAY`.
 
@@ -2985,7 +2985,7 @@ class ARRAY(
 
         __slots__ = ()
 
-        type: ARRAY
+        type: ARRAY[_T]
 
         @overload
         def _setup_getitem(
index 45ec981bbae2758975f40afadcaa792f092c4627..bc05ef8c4418d5b99a23e61a4b544751fdee57da 100644 (file)
@@ -117,3 +117,9 @@ reveal_type(array_of_ints)
 
 # EXPECTED_MYPY: Cannot infer type argument 1 of "array"
 array([0], type_=Text)
+
+# EXPECTED_TYPE: ARRAY[str]
+reveal_type(ARRAY(Text))
+
+# EXPECTED_TYPE: Column[Sequence[int]]
+reveal_type(Column(type_=ARRAY(Integer)))