]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add typing for PG UUID, other types
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 May 2022 13:56:01 +0000 (09:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 27 May 2022 13:56:01 +0000 (09:56 -0400)
note that UUID will be generalized into core with #7212.

Fixes: #6402
Change-Id: I90f0052ca74367c2c2f1ce2f8a90e81d173d1430

lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/sql/sqltypes.py
test/ext/mypy/plain_files/pg_stuff.py

index 298485f40e739b3127e3403bd955432c4a215f4a..3b5eaed30e3c9107ee3b95b0badb480429fff362 100644 (file)
@@ -11,12 +11,14 @@ from __future__ import annotations
 
 import re
 from typing import Any
+from typing import Optional
 from typing import TypeVar
 
 from ... import types as sqltypes
 from ... import util
 from ...sql import expression
 from ...sql import operators
+from ...sql._typing import _TypeEngineArgument
 
 
 _T = TypeVar("_T", bound=Any)
@@ -244,7 +246,11 @@ class ARRAY(sqltypes.ARRAY):
     comparator_factory = Comparator
 
     def __init__(
-        self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+        self,
+        item_type: _TypeEngineArgument[Any],
+        as_tuple: bool = False,
+        dimensions: Optional[int] = None,
+        zero_indexes: bool = False,
     ):
         """Construct an ARRAY.
 
index 6a49e296cae11902e9546f5adb7d6756fab72b23..0aeeb806ba01ed651510b9d3d958144861bb57e4 100644 (file)
@@ -6,7 +6,6 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
 
-
 r"""
 .. dialect:: postgresql
     :name: PostgreSQL
@@ -1448,9 +1447,14 @@ E.g.::
 
 """  # noqa: E501
 
+from __future__ import annotations
+
 from collections import defaultdict
 import datetime as dt
 import re
+from typing import Any
+from typing import overload
+from typing import TypeVar
 from uuid import UUID as _python_UUID
 
 from . import array as _array
@@ -1486,6 +1490,7 @@ from ...types import REAL
 from ...types import SMALLINT
 from ...types import TEXT
 from ...types import VARCHAR
+from ...util.typing import Literal
 
 IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
 
@@ -1601,32 +1606,32 @@ _FLOAT_TYPES = (700, 701, 1021, 1022)
 _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
 
 
-class BYTEA(sqltypes.LargeBinary):
+class BYTEA(sqltypes.LargeBinary[bytes]):
     __visit_name__ = "BYTEA"
 
 
-class INET(sqltypes.TypeEngine):
+class INET(sqltypes.TypeEngine[str]):
     __visit_name__ = "INET"
 
 
 PGInet = INET
 
 
-class CIDR(sqltypes.TypeEngine):
+class CIDR(sqltypes.TypeEngine[str]):
     __visit_name__ = "CIDR"
 
 
 PGCidr = CIDR
 
 
-class MACADDR(sqltypes.TypeEngine):
+class MACADDR(sqltypes.TypeEngine[str]):
     __visit_name__ = "MACADDR"
 
 
 PGMacAddr = MACADDR
 
 
-class MONEY(sqltypes.TypeEngine):
+class MONEY(sqltypes.TypeEngine[str]):
 
     r"""Provide the PostgreSQL MONEY type.
 
@@ -1671,7 +1676,7 @@ class MONEY(sqltypes.TypeEngine):
     __visit_name__ = "MONEY"
 
 
-class OID(sqltypes.TypeEngine):
+class OID(sqltypes.TypeEngine[int]):
 
     """Provide the PostgreSQL OID type.
 
@@ -1682,7 +1687,7 @@ class OID(sqltypes.TypeEngine):
     __visit_name__ = "OID"
 
 
-class REGCLASS(sqltypes.TypeEngine):
+class REGCLASS(sqltypes.TypeEngine[str]):
 
     """Provide the PostgreSQL REGCLASS type.
 
@@ -1745,7 +1750,7 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
 PGInterval = INTERVAL
 
 
-class BIT(sqltypes.TypeEngine):
+class BIT(sqltypes.TypeEngine[int]):
     __visit_name__ = "BIT"
 
     def __init__(self, length=None, varying=False):
@@ -1760,8 +1765,10 @@ class BIT(sqltypes.TypeEngine):
 
 PGBit = BIT
 
+_UUID_RETURN = TypeVar("_UUID_RETURN", str, _python_UUID)
+
 
-class UUID(sqltypes.TypeEngine):
+class UUID(sqltypes.TypeEngine[_UUID_RETURN]):
 
     """PostgreSQL UUID type.
 
@@ -1777,7 +1784,15 @@ class UUID(sqltypes.TypeEngine):
 
     __visit_name__ = "UUID"
 
-    def __init__(self, as_uuid=True):
+    @overload
+    def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...):
+        ...
+
+    @overload
+    def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...):
+        ...
+
+    def __init__(self, as_uuid: bool = True):
         """Construct a UUID type.
 
 
@@ -1848,7 +1863,7 @@ class UUID(sqltypes.TypeEngine):
 PGUuid = UUID
 
 
-class TSVECTOR(sqltypes.TypeEngine):
+class TSVECTOR(sqltypes.TypeEngine[Any]):
 
     """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
     text search type TSVECTOR.
index 4b0b408f7f239152edabb6c3c3f32d74435d8769..90b4b9c9e1b861ed414cb22b36bdebf575210751 100644 (file)
@@ -46,8 +46,8 @@ from .elements import TypeCoerce as type_coerce  # noqa
 from .type_api import Emulated
 from .type_api import NativeForEmulated  # noqa
 from .type_api import to_instance
-from .type_api import TypeDecorator
-from .type_api import TypeEngine
+from .type_api import TypeDecorator as TypeDecorator
+from .type_api import TypeEngine as TypeEngine
 from .type_api import TypeEngineMixin
 from .type_api import Variant  # noqa
 from .visitors import InternalTraversal
index ce0272397265e9b50d3df05964785aed1ad5d72f..c90bb67f0e89f388dcc253259ba128a515141e0d 100644 (file)
@@ -1,3 +1,7 @@
+from typing import Any
+from typing import Dict
+from uuid import UUID as _py_uuid
+
 from sqlalchemy import cast
 from sqlalchemy import Column
 from sqlalchemy import func
@@ -8,7 +12,23 @@ from sqlalchemy import Text
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+# test #6402
+
+c1 = Column(UUID())
+
+# EXPECTED_TYPE: Column[UUID]
+reveal_type(c1)
+
+c2 = Column(UUID(as_uuid=False))
+
+# EXPECTED_TYPE: Column[str]
+reveal_type(c2)
 
 
 class Base(DeclarativeBase):
@@ -18,8 +38,12 @@ class Base(DeclarativeBase):
 class Test(Base):
     __tablename__ = "test_table_json"
 
-    id = Column(Integer, primary_key=True)
-    data = Column(JSONB)
+    id = mapped_column(Integer, primary_key=True)
+    data: Mapped[Dict[str, Any]] = mapped_column(JSONB)
+
+    ident: Mapped[_py_uuid] = mapped_column(UUID())
+
+    ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False))
 
 
 elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem")
@@ -35,3 +59,12 @@ stmt = select(Test).where(
     )
 )
 print(stmt)
+
+
+t1 = Test()
+
+# EXPECTED_RE_TYPE: .*[dD]ict\[.*str, Any\]
+reveal_type(t1.data)
+
+# EXPECTED_TYPE: UUID
+reveal_type(t1.ident)