From 21ae13765d7410228672a282fef29fc0e2b3b098 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 27 May 2022 09:56:01 -0400 Subject: [PATCH] add typing for PG UUID, other types note that UUID will be generalized into core with #7212. Fixes: #6402 Change-Id: I90f0052ca74367c2c2f1ce2f8a90e81d173d1430 --- lib/sqlalchemy/dialects/postgresql/array.py | 8 ++++- lib/sqlalchemy/dialects/postgresql/base.py | 39 ++++++++++++++------- lib/sqlalchemy/sql/sqltypes.py | 4 +-- test/ext/mypy/plain_files/pg_stuff.py | 37 +++++++++++++++++-- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 298485f40e..3b5eaed30e 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -11,12 +11,14 @@ from __future__ import annotations import re from typing import Any +from typing import Optional from typing import TypeVar from ... import types as sqltypes from ... import util from ...sql import expression from ...sql import operators +from ...sql._typing import _TypeEngineArgument _T = TypeVar("_T", bound=Any) @@ -244,7 +246,11 @@ class ARRAY(sqltypes.ARRAY): comparator_factory = Comparator def __init__( - self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + self, + item_type: _TypeEngineArgument[Any], + as_tuple: bool = False, + dimensions: Optional[int] = None, + zero_indexes: bool = False, ): """Construct an ARRAY. diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 6a49e296ca..0aeeb806ba 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -6,7 +6,6 @@ # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors - r""" .. dialect:: postgresql :name: PostgreSQL @@ -1448,9 +1447,14 @@ E.g.:: """ # noqa: E501 +from __future__ import annotations + from collections import defaultdict import datetime as dt import re +from typing import Any +from typing import overload +from typing import TypeVar from uuid import UUID as _python_UUID from . import array as _array @@ -1486,6 +1490,7 @@ from ...types import REAL from ...types import SMALLINT from ...types import TEXT from ...types import VARCHAR +from ...util.typing import Literal IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) @@ -1601,32 +1606,32 @@ _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) -class BYTEA(sqltypes.LargeBinary): +class BYTEA(sqltypes.LargeBinary[bytes]): __visit_name__ = "BYTEA" -class INET(sqltypes.TypeEngine): +class INET(sqltypes.TypeEngine[str]): __visit_name__ = "INET" PGInet = INET -class CIDR(sqltypes.TypeEngine): +class CIDR(sqltypes.TypeEngine[str]): __visit_name__ = "CIDR" PGCidr = CIDR -class MACADDR(sqltypes.TypeEngine): +class MACADDR(sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR" PGMacAddr = MACADDR -class MONEY(sqltypes.TypeEngine): +class MONEY(sqltypes.TypeEngine[str]): r"""Provide the PostgreSQL MONEY type. @@ -1671,7 +1676,7 @@ class MONEY(sqltypes.TypeEngine): __visit_name__ = "MONEY" -class OID(sqltypes.TypeEngine): +class OID(sqltypes.TypeEngine[int]): """Provide the PostgreSQL OID type. @@ -1682,7 +1687,7 @@ class OID(sqltypes.TypeEngine): __visit_name__ = "OID" -class REGCLASS(sqltypes.TypeEngine): +class REGCLASS(sqltypes.TypeEngine[str]): """Provide the PostgreSQL REGCLASS type. @@ -1745,7 +1750,7 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): PGInterval = INTERVAL -class BIT(sqltypes.TypeEngine): +class BIT(sqltypes.TypeEngine[int]): __visit_name__ = "BIT" def __init__(self, length=None, varying=False): @@ -1760,8 +1765,10 @@ class BIT(sqltypes.TypeEngine): PGBit = BIT +_UUID_RETURN = TypeVar("_UUID_RETURN", str, _python_UUID) + -class UUID(sqltypes.TypeEngine): +class UUID(sqltypes.TypeEngine[_UUID_RETURN]): """PostgreSQL UUID type. @@ -1777,7 +1784,15 @@ class UUID(sqltypes.TypeEngine): __visit_name__ = "UUID" - def __init__(self, as_uuid=True): + @overload + def __init__(self: "UUID[_python_UUID]", as_uuid: Literal[True] = ...): + ... + + @overload + def __init__(self: "UUID[str]", as_uuid: Literal[False] = ...): + ... + + def __init__(self, as_uuid: bool = True): """Construct a UUID type. @@ -1848,7 +1863,7 @@ class UUID(sqltypes.TypeEngine): PGUuid = UUID -class TSVECTOR(sqltypes.TypeEngine): +class TSVECTOR(sqltypes.TypeEngine[Any]): """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL text search type TSVECTOR. diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 4b0b408f7f..90b4b9c9e1 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -46,8 +46,8 @@ from .elements import TypeCoerce as type_coerce # noqa from .type_api import Emulated from .type_api import NativeForEmulated # noqa from .type_api import to_instance -from .type_api import TypeDecorator -from .type_api import TypeEngine +from .type_api import TypeDecorator as TypeDecorator +from .type_api import TypeEngine as TypeEngine from .type_api import TypeEngineMixin from .type_api import Variant # noqa from .visitors import InternalTraversal diff --git a/test/ext/mypy/plain_files/pg_stuff.py b/test/ext/mypy/plain_files/pg_stuff.py index ce02723972..c90bb67f0e 100644 --- a/test/ext/mypy/plain_files/pg_stuff.py +++ b/test/ext/mypy/plain_files/pg_stuff.py @@ -1,3 +1,7 @@ +from typing import Any +from typing import Dict +from uuid import UUID as _py_uuid + from sqlalchemy import cast from sqlalchemy import Column from sqlalchemy import func @@ -8,7 +12,23 @@ from sqlalchemy import Text from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +# test #6402 + +c1 = Column(UUID()) + +# EXPECTED_TYPE: Column[UUID] +reveal_type(c1) + +c2 = Column(UUID(as_uuid=False)) + +# EXPECTED_TYPE: Column[str] +reveal_type(c2) class Base(DeclarativeBase): @@ -18,8 +38,12 @@ class Base(DeclarativeBase): class Test(Base): __tablename__ = "test_table_json" - id = Column(Integer, primary_key=True) - data = Column(JSONB) + id = mapped_column(Integer, primary_key=True) + data: Mapped[Dict[str, Any]] = mapped_column(JSONB) + + ident: Mapped[_py_uuid] = mapped_column(UUID()) + + ident_str: Mapped[str] = mapped_column(UUID(as_uuid=False)) elem = func.jsonb_array_elements(Test.data, type_=JSONB).column_valued("elem") @@ -35,3 +59,12 @@ stmt = select(Test).where( ) ) print(stmt) + + +t1 = Test() + +# EXPECTED_RE_TYPE: .*[dD]ict\[.*str, Any\] +reveal_type(t1.data) + +# EXPECTED_TYPE: UUID +reveal_type(t1.ident) -- 2.47.2