From: Federico Caselli Date: Tue, 19 Nov 2024 22:12:51 +0000 (+0100) Subject: General improvement on annotated declarative X-Git-Tag: rel_2_0_37~8^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=43d29b9695eb8229c70fafe87616ccc9ad969b3f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git General improvement on annotated declarative Fix issue that resulted in inconsistent handing of unions depending on how they were declared Consistently support TypeAliasType. This has required a revision of the implementation added in #11305 to have a consistent behavior. References: #11944 References: #11955 References: #11305 Change-Id: Iffc34fd42b9769f73ddb4331bd59b6b37391635d (cherry picked from commit e6b0b421d60ecf660cf3872f3f32dd2b7a739b59) --- diff --git a/doc/build/changelog/unreleased_20/11944.rst b/doc/build/changelog/unreleased_20/11944.rst new file mode 100644 index 0000000000..e7469180ec --- /dev/null +++ b/doc/build/changelog/unreleased_20/11944.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: bug, orm + :tickets: 11944 + + Fixed bug in how type unions were handled that made the behavior + of ``a | b`` different from ``Union[a, b]``. diff --git a/doc/build/changelog/unreleased_20/11955.rst b/doc/build/changelog/unreleased_20/11955.rst new file mode 100644 index 0000000000..eeeb2bcbdd --- /dev/null +++ b/doc/build/changelog/unreleased_20/11955.rst @@ -0,0 +1,13 @@ +.. change:: + :tags: bug, orm + :tickets: 11955 + + Consistently handle ``TypeAliasType`` (defined in PEP 695) obtained with the + ``type X = int`` syntax introduced in python 3.12. + Now in all cases one such alias must be explicitly added to the type map for + it to be usable inside ``Mapped[]``. + This change also revises the approach added in :ticket:`11305`, now requiring + the ``TypeAliasType`` to be added to the type map. + Documentation on how unions and type alias types are handled by SQLAlchemy + has been added in the :ref:`orm_declarative_mapped_column_type_map` section + of the documentation. diff --git a/doc/build/orm/declarative_tables.rst b/doc/build/orm/declarative_tables.rst index b2c91981b3..4bb4237ac1 100644 --- a/doc/build/orm/declarative_tables.rst +++ b/doc/build/orm/declarative_tables.rst @@ -316,9 +316,8 @@ the registry and Declarative base could be configured as:: import datetime - from sqlalchemy import BIGINT, Integer, NVARCHAR, String, TIMESTAMP - from sqlalchemy.orm import DeclarativeBase - from sqlalchemy.orm import Mapped, mapped_column, registry + from sqlalchemy import BIGINT, NVARCHAR, String, TIMESTAMP + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class Base(DeclarativeBase): @@ -369,6 +368,59 @@ while still being able to use succinct annotation-only :func:`_orm.mapped_column configurations. There are two more levels of Python-type configurability available beyond this, described in the next two sections. +Union types inside the Type Map +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +SQLAlchemy supports mapping union types inside the type map to allow +mapping database types that can support multiple Python types, +such as :class:`_types.JSON` or :class:`_postgresql.JSONB`:: + + from sqlalchemy import JSON + from sqlalchemy.dialects import postgresql + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + from sqlalchemy.schema import CreateTable + + json_list = list[int] | list[str] + json_scalar = float | str | bool | None + + + class Base(DeclarativeBase): + type_annotation_map = { + json_list: postgresql.JSONB, + json_scalar: JSON, + } + + + class SomeClass(Base): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + list_col: Mapped[list[str] | list[int]] + scalar_col: Mapped[json_scalar] + scalar_col_not_null: Mapped[str | float | bool] + +Using the union directly inside ``Mapped`` or creating a new one with the same +effective types has the same behavior: ``list_col`` will be matched to the +``json_list`` union even if it does not reference it directly (the order of the +types also does not matter). +If the union added to the type map includes ``None``, it will be ignored +when matching the ``Mapped`` type since ``None`` is only used to decide +the column nullability. It follows that both ``scalar_col`` and +``scalar_col_not_null`` will match the ``json_scalar`` union. + +The CREATE TABLE statement of the table created above is as follows: + +.. sourcecode:: pycon+sql + + >>> print(CreateTable(SomeClass.__table__).compile(dialect=postgresql.dialect())) + {printsql}CREATE TABLE some_table ( + id SERIAL NOT NULL, + list_col JSONB NOT NULL, + scalar_col JSON, + scalar_col_not_null JSON NOT NULL, + PRIMARY KEY (id) + ) + .. _orm_declarative_mapped_column_type_map_pep593: Mapping Multiple Type Configurations to Python Types @@ -458,6 +510,96 @@ us a wide degree of flexibility, the next section illustrates a second way in which ``Annotated`` may be used with Declarative that is even more open ended. +Support for Type Alias Types (defined by PEP 695) and NewType +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The typing module allows an user to create "new types" using ``typing.NewType``:: + + from typing import NewType + + nstr30 = NewType("nstr30", str) + nstr50 = NewType("nstr50", str) + +These are considered as different by the type checkers and by python:: + + >>> print(str == nstr30, nstr50 == nstr30, nstr30 == NewType("nstr30", str)) + False False False + +Another similar feature was added in Python 3.12 to create aliases, +using a new syntax to define ``typing.TypeAliasType``:: + + type SmallInt = int + type BigInt = int + type JsonScalar = str | float | bool | None + +Like ``typing.NewType``, these are treated by python as different, meaning that they are +not equal between each other even if they represent the same Python type. +In the example above, ``SmallInt`` and ``BigInt`` are not considered equal even +if they both are aliases of the python type ``int``:: + + >>> print(SmallInt == BigInt) + False + +SQLAlchemy supports using ``typing.NewType`` and ``typing.TypeAliasType`` +in the ``type_annotation_map``. They can be used to associate the same python type +to different :class:`_types.TypeEngine` types, similarly +to ``typing.Annotated``:: + + from sqlalchemy import SmallInteger, BigInteger, JSON, String + from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + from sqlalchemy.schema import CreateTable + + + class TABase(DeclarativeBase): + type_annotation_map = { + nstr30: String(30), + nstr50: String(50), + SmallInt: SmallInteger, + BigInteger: BigInteger, + JsonScalar: JSON, + } + + + class SomeClass(TABase): + __tablename__ = "some_table" + + id: Mapped[int] = mapped_column(primary_key=True) + normal_str: Mapped[str] + + short_str: Mapped[nstr30] + long_str: Mapped[nstr50] + + small_int: Mapped[SmallInt] + big_int: Mapped[BigInteger] + scalar_col: Mapped[JsonScalar] + +a CREATE TABLE for the above mapping will illustrate the different variants +of integer and string we've configured, and looks like: + +.. sourcecode:: pycon+sql + + >>> print(CreateTable(SomeClass.__table__)) + {printsql}CREATE TABLE some_table ( + id INTEGER NOT NULL, + normal_str VARCHAR NOT NULL, + short_str VARCHAR(30) NOT NULL, + long_str VARCHAR(50) NOT NULL, + small_int SMALLINT NOT NULL, + big_int BIGINT NOT NULL, + scalar_col JSON, + PRIMARY KEY (id) + ) + +Since the ``JsonScalar`` type includes ``None`` the columns is nullable, while +``id`` and ``normal_str`` columns use the default mapping for their respective +Python type. + +As mentioned above, since ``typing.NewType`` and ``typing.TypeAliasType`` are +considered standalone types, they must be referenced directly inside ``Mapped`` +and must be added explicitly to the type map. +Failing to do so will raise an error since SQLAlchemy does not know what +SQL type to use. + .. _orm_declarative_mapped_column_pep593: Mapping Whole Column Declarations to Python Types @@ -743,6 +885,28 @@ appropriate settings, including default string length. If a ``typing.Literal`` that does not consist of only string values is passed, an informative error is raised. +``typing.TypeAliasType`` can also be used to create enums, by assigning them +to a ``typing.Literal`` of strings:: + + from typing import Literal + + type Status = Literal["on", "off", "unknown"] + +Since this is a ``typing.TypeAliasType``, it represents a unique type object, +so it must be placed in the ``type_annotation_map`` for it to be looked up +successfully, keyed to the :class:`.Enum` type as follows:: + + import enum + import sqlalchemy + + + class Base(DeclarativeBase): + type_annotation_map = {Status: sqlalchemy.Enum(enum.Enum)} + +Since SQLAlchemy supports mapping different ``typing.TypeAliasType`` +objects that are otherwise structurally equivalent individually, +these must be present in ``type_annotation_map`` to avoid ambiguity. + Native Enums and Naming +++++++++++++++++++++++ diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 718cf72516..a3b0ac21f0 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -14,7 +14,6 @@ import re import typing from typing import Any from typing import Callable -from typing import cast from typing import ClassVar from typing import Dict from typing import FrozenSet @@ -72,6 +71,7 @@ from ..sql.selectable import FromClause from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing +from ..util import warn_deprecated from ..util.typing import CallableReference from ..util.typing import de_optionalize_union_types from ..util.typing import flatten_newtype @@ -80,6 +80,7 @@ from ..util.typing import is_literal from ..util.typing import is_newtype from ..util.typing import is_pep695 from ..util.typing import Literal +from ..util.typing import LITERAL_TYPES from ..util.typing import Self if TYPE_CHECKING: @@ -1232,40 +1233,27 @@ class registry: ) def _resolve_type( - self, python_type: _MatchedOnType + self, python_type: _MatchedOnType, _do_fallbacks: bool = True ) -> Optional[sqltypes.TypeEngine[Any]]: - - python_type_to_check = python_type - while is_pep695(python_type_to_check): - python_type_to_check = python_type_to_check.__value__ - - check_is_pt = python_type is python_type_to_check - python_type_type: Type[Any] search: Iterable[Tuple[_MatchedOnType, Type[Any]]] - if is_generic(python_type_to_check): - if is_literal(python_type_to_check): - python_type_type = cast("Type[Any]", python_type_to_check) + if is_generic(python_type): + if is_literal(python_type): + python_type_type = python_type # type: ignore[assignment] - search = ( # type: ignore[assignment] + search = ( (python_type, python_type_type), - (Literal, python_type_type), + *((lt, python_type_type) for lt in LITERAL_TYPES), # type: ignore[arg-type] # noqa: E501 ) else: - python_type_type = python_type_to_check.__origin__ + python_type_type = python_type.__origin__ search = ((python_type, python_type_type),) - elif is_newtype(python_type_to_check): - python_type_type = flatten_newtype(python_type_to_check) - search = ((python_type, python_type_type),) - elif isinstance(python_type_to_check, type): - python_type_type = python_type_to_check - search = ( - (pt if check_is_pt else python_type, pt) - for pt in python_type_type.__mro__ - ) + elif isinstance(python_type, type): + python_type_type = python_type + search = ((pt, pt) for pt in python_type_type.__mro__) else: - python_type_type = python_type_to_check # type: ignore[assignment] + python_type_type = python_type # type: ignore[assignment] search = ((python_type, python_type_type),) for pt, flattened in search: @@ -1290,6 +1278,39 @@ class registry: if resolved_sql_type is not None: return resolved_sql_type + # 2.0 fallbacks + if _do_fallbacks: + python_type_to_check: Any = None + kind = None + if is_pep695(python_type): + # NOTE: assume there aren't type alias types of new types. + python_type_to_check = python_type + while is_pep695(python_type_to_check): + python_type_to_check = python_type_to_check.__value__ + python_type_to_check = de_optionalize_union_types( + python_type_to_check + ) + kind = "TypeAliasType" + if is_newtype(python_type): + python_type_to_check = flatten_newtype(python_type) + kind = "NewType" + + if python_type_to_check is not None: + res_after_fallback = self._resolve_type( + python_type_to_check, False + ) + if res_after_fallback is not None: + assert kind is not None + warn_deprecated( + f"Matching the provided {kind} '{python_type}' on " + "its resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to " + "the type_annotation_map to allow it to match " + "explicitly.", + "2.0", + ) + return res_after_fallback + return None @property diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index b069d23c0f..aa64eaa666 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -65,11 +65,11 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType +from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ClassDict @@ -1319,7 +1319,7 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, mapped_container = extracted if attr_value is None and not is_literal(extracted_mapped_annotation): - for elem in typing_get_args(extracted_mapped_annotation): + for elem in get_args(extracted_mapped_annotation): if isinstance(elem, str) or is_fwd_ref( elem, check_generic=True ): diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index faf287cce6..4e07050a1d 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -53,9 +53,10 @@ from .. import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter +from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 -from ..util.typing import typing_get_args + if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -364,7 +365,7 @@ class CompositeProperty( argument = extracted_mapped_annotation if is_pep593(argument): - argument = typing_get_args(argument)[0] + argument = get_args(argument)[0] if argument and self.composite_class is None: if isinstance(argument, str) or is_fwd_ref( diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index b6fb3d43e3..96ae9d7f82 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -55,13 +55,13 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types +from ..util.typing import get_args +from ..util.typing import includes_none from ..util.typing import is_fwd_ref -from ..util.typing import is_optional_union from ..util.typing import is_pep593 from ..util.typing import is_pep695 from ..util.typing import is_union from ..util.typing import Self -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -752,38 +752,36 @@ class MappedColumn( cls, argument, originating_module ) - nullable = is_optional_union(argument) + nullable = includes_none(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - use_args_from = None - - our_original_type = our_type - - if is_pep695(our_type): - our_type = our_type.__value__ + find_mapped_in: Tuple[Any, ...] = () + our_type_is_pep593 = False + raw_pep_593_type = None if is_pep593(our_type): our_type_is_pep593 = True - pep_593_components = typing_get_args(our_type) + pep_593_components = get_args(our_type) raw_pep_593_type = pep_593_components[0] - if is_optional_union(raw_pep_593_type): + if nullable: raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type) - - nullable = True - if not self._has_nullable: - self.column.nullable = nullable - for elem in pep_593_components[1:]: - if isinstance(elem, MappedColumn): - use_args_from = elem - break + find_mapped_in = pep_593_components[1:] + elif is_pep695(argument) and is_pep593(argument.__value__): + # do not support nested annotation inside unions ets + find_mapped_in = get_args(argument.__value__)[1:] + + use_args_from: Optional[MappedColumn[Any]] + for elem in find_mapped_in: + if isinstance(elem, MappedColumn): + use_args_from = elem + break else: - our_type_is_pep593 = False - raw_pep_593_type = None + use_args_from = None if use_args_from is not None: if ( @@ -857,10 +855,11 @@ class MappedColumn( if sqltype._isnull and not self.column.foreign_keys: new_sqltype = None + checks: List[Any] if our_type_is_pep593: - checks = [our_original_type, raw_pep_593_type] + checks = [our_type, raw_pep_593_type] else: - checks = [our_original_type] + checks = [our_type] for check_type in checks: new_sqltype = registry._resolve_type(check_type) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index dbfa6d5f1b..11b6ac2c1c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -91,10 +91,10 @@ from ..util.typing import ( ) from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import fixup_container_fwd_refs +from ..util.typing import get_origin from ..util.typing import is_origin_of_cls from ..util.typing import Literal from ..util.typing import Protocol -from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -123,7 +123,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Selectable from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType - from ..util.typing import ArgsTypeProcotol + from ..util.typing import ArgsTypeProtocol _T = TypeVar("_T", bound=Any) @@ -177,7 +177,7 @@ class _DeStringifyUnionElements(Protocol): def __call__( self, cls: Type[Any], - annotation: ArgsTypeProcotol, + annotation: ArgsTypeProtocol, originating_module: str, *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, @@ -1543,7 +1543,7 @@ GenericAlias = type(List[Any]) def _inspect_generic_alias( class_: Type[_O], ) -> Optional[Mapper[_O]]: - origin = cast("Type[_O]", typing_get_origin(class_)) + origin = cast("Type[_O]", get_origin(class_)) return _inspect_mc(origin) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index f16db64066..a7d140ec6b 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,9 +59,11 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict +from ..util import warn_deprecated +from ..util.typing import get_args from ..util.typing import is_literal +from ..util.typing import is_pep695 from ..util.typing import Literal -from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ColumnExpressionArgument @@ -1511,6 +1513,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): native_enum = None + def process_literal(pt): + # for a literal, where we need to get its contents, parse it out. + enum_args = get_args(pt) + bad_args = [arg for arg in enum_args if not isinstance(arg, str)] + if bad_args: + raise exc.ArgumentError( + f"Can't create string-based Enum datatype from non-string " + f"values: {', '.join(repr(x) for x in bad_args)}. Please " + f"provide an explicit Enum datatype for this Python type" + ) + native_enum = False + return enum_args, native_enum + if not we_are_generic_form and python_type is matched_on: # if we have enumerated values, and the incoming python # type is exactly the one that matched in the type map, @@ -1519,16 +1534,32 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): enum_args = self._enums_argument elif is_literal(python_type): - # for a literal, where we need to get its contents, parse it out. - enum_args = typing_get_args(python_type) - bad_args = [arg for arg in enum_args if not isinstance(arg, str)] - if bad_args: + enum_args, native_enum = process_literal(python_type) + elif is_pep695(python_type): + value = python_type.__value__ + if is_pep695(value): + new_value = value + while is_pep695(new_value): + new_value = new_value.__value__ + if is_literal(new_value): + value = new_value + warn_deprecated( + f"Mapping recursive TypeAliasType '{python_type}' " + "that resolve to literal to generate an Enum is " + "deprecated. SQLAlchemy 2.1 will not support this " + "use case. Please avoid using recursing " + "TypeAliasType.", + "2.0", + ) + if not is_literal(value): raise exc.ArgumentError( - f"Can't create string-based Enum datatype from non-string " - f"values: {', '.join(repr(x) for x in bad_args)}. Please " - f"provide an explicit Enum datatype for this Python type" + f"Can't associate TypeAliasType '{python_type}' to an " + "Enum since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal['a', " + "'b']` are supported when generating Enums." ) - native_enum = False + enum_args, native_enum = process_literal(value) + elif isinstance(python_type, type) and issubclass( python_type, enum.Enum ): diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index bd1ebd4c01..645a41a240 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -9,6 +9,7 @@ from __future__ import annotations import builtins +from collections import deque import collections.abc as collections_abc import re import sys @@ -54,6 +55,7 @@ if True: # zimports removes the tailing comments from typing_extensions import TypeGuard as TypeGuard # 3.10 from typing_extensions import Self as Self # 3.11 from typing_extensions import TypeAliasType as TypeAliasType # 3.12 + from typing_extensions import Never as Never # 3.11 _T = TypeVar("_T", bound=Any) _KT = TypeVar("_KT") @@ -65,9 +67,9 @@ _VT_co = TypeVar("_VT_co", covariant=True) if compat.py38: # typing_extensions.Literal is different from typing.Literal until # Python 3.10.1 - _LITERAL_TYPES = frozenset([typing.Literal, Literal]) + LITERAL_TYPES = frozenset([typing.Literal, Literal]) else: - _LITERAL_TYPES = frozenset([Literal]) + LITERAL_TYPES = frozenset([Literal]) if compat.py310: @@ -79,16 +81,13 @@ else: NoneFwd = ForwardRef("None") -typing_get_args = get_args -typing_get_origin = get_origin - _AnnotationScanType = Union[ Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]" ] -class ArgsTypeProcotol(Protocol): +class ArgsTypeProtocol(Protocol): """protocol for types that have ``__args__`` there's no public interface for this AFAIK @@ -209,7 +208,7 @@ def fixup_container_fwd_refs( if ( is_generic(type_) - and typing_get_origin(type_) + and get_origin(type_) in ( dict, set, @@ -229,11 +228,11 @@ def fixup_container_fwd_refs( ) ): # compat with py3.10 and earlier - return typing_get_origin(type_).__class_getitem__( # type: ignore + return get_origin(type_).__class_getitem__( # type: ignore tuple( [ ForwardRef(elem) if isinstance(elem, str) else elem - for elem in typing_get_args(type_) + for elem in get_args(type_) ] ) ) @@ -332,7 +331,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: def de_stringify_union_elements( cls: Type[Any], - annotation: ArgsTypeProcotol, + annotation: ArgsTypeProtocol, originating_module: str, locals_: Mapping[str, Any], *, @@ -352,8 +351,8 @@ def de_stringify_union_elements( ) -def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: - return type_ is not None and typing_get_origin(type_) is Annotated +def is_pep593(type_: Optional[Any]) -> bool: + return type_ is not None and get_origin(type_) is Annotated def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: @@ -362,8 +361,8 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: ) -def is_literal(type_: _AnnotationScanType) -> bool: - return get_origin(type_) in _LITERAL_TYPES +def is_literal(type_: Any) -> bool: + return get_origin(type_) in LITERAL_TYPES def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: @@ -389,6 +388,43 @@ def flatten_newtype(type_: NewType) -> Type[Any]: return super_type # type: ignore[return-value] +def pep695_values(type_: _AnnotationScanType) -> Set[Any]: + """Extracts the value from a TypeAliasType, recursively exploring unions + and inner TypeAliasType to flatten them into a single set. + + Forward references are not evaluated, so no recursive exploration happens + into them. + """ + _seen = set() + + def recursive_value(type_): + if type_ in _seen: + # recursion are not supported (at least it's flagged as + # an error by pyright). Just avoid infinite loop + return type_ + _seen.add(type_) + if not is_pep695(type_): + return type_ + value = type_.__value__ + if not is_union(value): + return value + return [recursive_value(t) for t in value.__args__] + + res = recursive_value(type_) + if isinstance(res, list): + types = set() + stack = deque(res) + while stack: + t = stack.popleft() + if isinstance(t, list): + stack.extend(t) + else: + types.add(None if t in {NoneType, NoneFwd} else t) + return types + else: + return {res} + + def is_fwd_ref( type_: _AnnotationScanType, check_generic: bool = False ) -> TypeGuard[ForwardRef]: @@ -422,13 +458,10 @@ def de_optionalize_union_types( """ - while is_pep695(type_): - type_ = type_.__value__ - if is_fwd_ref(type_): - return de_optionalize_fwd_ref_union_types(type_) + return _de_optionalize_fwd_ref_union_types(type_, False) - elif is_optional(type_): + elif is_union(type_) and includes_none(type_): typ = set(type_.__args__) typ.discard(NoneType) @@ -440,9 +473,21 @@ def de_optionalize_union_types( return type_ -def de_optionalize_fwd_ref_union_types( - type_: ForwardRef, -) -> _AnnotationScanType: +@overload +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: Literal[True] +) -> bool: ... + + +@overload +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: Literal[False] +) -> _AnnotationScanType: ... + + +def _de_optionalize_fwd_ref_union_types( + type_: ForwardRef, return_has_none: bool +) -> Union[_AnnotationScanType, bool]: """return the non-optional type for Optional[], Union[None, ...], x|None, etc. without de-stringifying forward refs. @@ -454,47 +499,77 @@ def de_optionalize_fwd_ref_union_types( mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if mm: - if mm.group(1) == "Optional": - return ForwardRef(mm.group(2)) - elif mm.group(1) == "Union": - elements = re.split(r",\s*", mm.group(2)) - return make_union_type( - *[ForwardRef(elem) for elem in elements if elem != "None"] - ) + g1 = mm.group(1).split(".")[-1] + if g1 == "Optional": + return True if return_has_none else ForwardRef(mm.group(2)) + elif g1 == "Union": + if "[" in mm.group(2): + # cases like "Union[Dict[str, int], int, None]" + elements: list[str] = [] + current: list[str] = [] + ignore_comma = 0 + for char in mm.group(2): + if char == "[": + ignore_comma += 1 + elif char == "]": + ignore_comma -= 1 + elif ignore_comma == 0 and char == ",": + elements.append("".join(current).strip()) + current.clear() + continue + current.append(char) + else: + elements = re.split(r",\s*", mm.group(2)) + parts = [ForwardRef(elem) for elem in elements if elem != "None"] + if return_has_none: + return len(elements) != len(parts) + else: + return make_union_type(*parts) if parts else Never # type: ignore[return-value] # noqa: E501 else: - return type_ + return False if return_has_none else type_ pipe_tokens = re.split(r"\s*\|\s*", annotation) - if "None" in pipe_tokens: - return ForwardRef("|".join(p for p in pipe_tokens if p != "None")) + has_none = "None" in pipe_tokens + if return_has_none: + return has_none + if has_none: + anno_str = "|".join(p for p in pipe_tokens if p != "None") + return ForwardRef(anno_str) if anno_str else Never # type: ignore[return-value] # noqa: E501 return type_ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: - """Make a Union type. + """Make a Union type.""" + return Union.__getitem__(types) # type: ignore - This is needed by :func:`.de_optionalize_union_types` which removes - ``NoneType`` from a ``Union``. - """ - return cast(Any, Union).__getitem__(types) # type: ignore - - -def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: - return is_origin_of( - type_, - "Optional", - "Union", - "UnionType", - ) +def includes_none(type_: Any) -> bool: + """Returns if the type annotation ``type_`` allows ``None``. - -def is_optional_union(type_: Any) -> bool: - return is_optional(type_) and NoneType in typing_get_args(type_) - - -def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: + This function supports: + * forward refs + * unions + * pep593 - Annotated + * pep695 - TypeAliasType (does not support looking into + fw reference of other pep695) + * NewType + * plain types like ``int``, ``None``, etc + """ + if is_fwd_ref(type_): + return _de_optionalize_fwd_ref_union_types(type_, True) + if is_union(type_): + return any(includes_none(t) for t in get_args(type_)) + if is_pep593(type_): + return includes_none(get_args(type_)[0]) + if is_pep695(type_): + return any(includes_none(t) for t in pep695_values(type_)) + if is_newtype(type_): + return includes_none(type_.__supertype__) + return type_ in (NoneFwd, NoneType, None) + + +def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]: return is_origin_of(type_, "Union", "UnionType") @@ -504,7 +579,7 @@ def is_origin_of_cls( """return True if the given type has an __origin__ that shares a base with the given class""" - origin = typing_get_origin(type_) + origin = get_origin(type_) if origin is None: return False @@ -517,7 +592,7 @@ def is_origin_of( """return True if the given type has an __origin__ with the given name and optional module.""" - origin = typing_get_origin(type_) + origin = get_origin(type_) if origin is None: return False @@ -607,6 +682,3 @@ class CallableReference(Generic[_FN]): def __set__(self, instance: Any, value: _FN) -> None: ... def __delete__(self, instance: Any) -> None: ... - - -# $def ro_descriptor_reference(fn: Callable[]) diff --git a/test/base/test_typing_utils.py b/test/base/test_typing_utils.py new file mode 100644 index 0000000000..67e7bf4143 --- /dev/null +++ b/test/base/test_typing_utils.py @@ -0,0 +1,409 @@ +# NOTE: typing implementation is full of heuristic so unit test it to avoid +# unexpected breakages. + +import typing + +import typing_extensions + +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import requires +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import is_ +from sqlalchemy.util import py310 +from sqlalchemy.util import py311 +from sqlalchemy.util import py312 +from sqlalchemy.util import py38 +from sqlalchemy.util import typing as sa_typing + +TV = typing.TypeVar("TV") + + +def union_types(): + res = [typing.Union[int, str]] + if py310: + res.append(int | str) + return res + + +def null_union_types(): + res = [ + typing.Optional[typing.Union[int, str]], + typing.Union[int, str, None], + typing.Union[int, str, "None"], + ] + if py310: + res.append(int | str | None) + res.append(typing.Optional[int | str]) + res.append(typing.Union[int, str] | None) + res.append(typing.Optional[int] | str) + return res + + +def make_fw_ref(anno: str) -> typing.ForwardRef: + return typing.Union[anno] + + +TA_int = typing_extensions.TypeAliasType("TA_int", int) +TA_union = typing_extensions.TypeAliasType("TA_union", typing.Union[int, str]) +TA_null_union = typing_extensions.TypeAliasType( + "TA_null_union", typing.Union[int, str, None] +) +TA_null_union2 = typing_extensions.TypeAliasType( + "TA_null_union2", typing.Union[int, str, "None"] +) +TA_null_union3 = typing_extensions.TypeAliasType( + "TA_null_union3", typing.Union[int, "typing.Union[None, bool]"] +) +TA_null_union4 = typing_extensions.TypeAliasType( + "TA_null_union4", typing.Union[int, "TA_null_union2"] +) +TA_union_ta = typing_extensions.TypeAliasType( + "TA_union_ta", typing.Union[TA_int, str] +) +TA_null_union_ta = typing_extensions.TypeAliasType( + "TA_null_union_ta", typing.Union[TA_null_union, float] +) +TA_list = typing_extensions.TypeAliasType( + "TA_list", typing.Union[int, str, typing.List["TA_list"]] +) +# these below not valid. Verify that it does not cause exceptions in any case +TA_recursive = typing_extensions.TypeAliasType( + "TA_recursive", typing.Union["TA_recursive", str] +) +TA_null_recursive = typing_extensions.TypeAliasType( + "TA_null_recursive", typing.Union[TA_recursive, None] +) +TA_recursive_a = typing_extensions.TypeAliasType( + "TA_recursive_a", typing.Union["TA_recursive_b", int] +) +TA_recursive_b = typing_extensions.TypeAliasType( + "TA_recursive_b", typing.Union["TA_recursive_a", str] +) + + +def type_aliases(): + return [ + TA_int, + TA_union, + TA_null_union, + TA_null_union2, + TA_null_union3, + TA_null_union4, + TA_union_ta, + TA_null_union_ta, + TA_list, + TA_recursive, + TA_null_recursive, + TA_recursive_a, + TA_recursive_b, + ] + + +NT_str = typing.NewType("NT_str", str) +NT_null = typing.NewType("NT_null", None) +# this below is not valid. Verify that it does not cause exceptions in any case +NT_union = typing.NewType("NT_union", typing.Union[str, int]) + + +def new_types(): + return [NT_str, NT_null, NT_union] + + +A_str = typing_extensions.Annotated[str, "meta"] +A_null_str = typing_extensions.Annotated[ + typing.Union[str, None], "other_meta", "null" +] +A_union = typing_extensions.Annotated[typing.Union[str, int], "other_meta"] +A_null_union = typing_extensions.Annotated[ + typing.Union[str, int, None], "other_meta", "null" +] + + +def annotated_l(): + return [A_str, A_null_str, A_union, A_null_union] + + +def all_types(): + return ( + union_types() + + null_union_types() + + type_aliases() + + new_types() + + annotated_l() + ) + + +def exec_code(code: str, *vars: str) -> typing.Any: + assert vars + scope = {} + exec(code, None, scope) + if len(vars) == 1: + return scope[vars[0]] + return [scope[name] for name in vars] + + +class TestTestingThings(fixtures.TestBase): + def test_unions_are_the_same(self): + # no need to test typing_extensions.Union, typing_extensions.Optional + is_(typing.Union, typing_extensions.Union) + is_(typing.Optional, typing_extensions.Optional) + if py312: + is_(typing.TypeAliasType, typing_extensions.TypeAliasType) + + def test_make_union(self): + v = int, str + eq_(typing.Union[int, str], typing.Union.__getitem__(v)) + if py311: + # need eval since it's a syntax error in python < 3.11 + eq_(typing.Union[int, str], eval("typing.Union[*(int, str)]")) + eq_(typing.Union[int, str], eval("typing.Union[*v]")) + + @requires.python312 + def test_make_type_alias_type(self): + # verify that TypeAliasType('foo', int) it the same as 'type foo = int' + x_type = exec_code("type x = int", "x") + x = typing.TypeAliasType("x", int) + + eq_(type(x_type), type(x)) + eq_(x_type.__name__, x.__name__) + eq_(x_type.__value__, x.__value__) + + def test_make_fw_ref(self): + eq_(make_fw_ref("str"), typing.ForwardRef("str")) + eq_(make_fw_ref("str|int"), typing.ForwardRef("str|int")) + eq_( + make_fw_ref("Optional[Union[str, int]]"), + typing.ForwardRef("Optional[Union[str, int]]"), + ) + + +class TestTyping(fixtures.TestBase): + def test_is_pep593(self): + eq_(sa_typing.is_pep593(str), False) + eq_(sa_typing.is_pep593(None), False) + eq_(sa_typing.is_pep593(typing_extensions.Annotated[int, "a"]), True) + if py310: + eq_(sa_typing.is_pep593(typing.Annotated[int, "a"]), True) + + for t in annotated_l(): + eq_(sa_typing.is_pep593(t), True) + for t in ( + union_types() + null_union_types() + type_aliases() + new_types() + ): + eq_(sa_typing.is_pep593(t), False) + + def test_is_literal(self): + if py38: + eq_(sa_typing.is_literal(typing.Literal["a"]), True) + eq_(sa_typing.is_literal(typing_extensions.Literal["a"]), True) + eq_(sa_typing.is_literal(None), False) + for t in all_types(): + eq_(sa_typing.is_literal(t), False) + + def test_is_newtype(self): + eq_(sa_typing.is_newtype(str), False) + + for t in new_types(): + eq_(sa_typing.is_newtype(t), True) + for t in ( + union_types() + null_union_types() + type_aliases() + annotated_l() + ): + eq_(sa_typing.is_newtype(t), False) + + def test_is_generic(self): + class W(typing.Generic[TV]): + pass + + eq_(sa_typing.is_generic(typing.List[int]), True) + eq_(sa_typing.is_generic(W), False) + eq_(sa_typing.is_generic(W[str]), True) + + if py312: + t = exec_code("class W[T]: pass", "W") + eq_(sa_typing.is_generic(t), False) + eq_(sa_typing.is_generic(t[int]), True) + + for t in all_types(): + eq_(sa_typing.is_literal(t), False) + + def test_is_pep695(self): + eq_(sa_typing.is_pep695(str), False) + for t in ( + union_types() + null_union_types() + new_types() + annotated_l() + ): + eq_(sa_typing.is_pep695(t), False) + for t in type_aliases(): + eq_(sa_typing.is_pep695(t), True) + + def test_pep695_value(self): + eq_(sa_typing.pep695_values(int), {int}) + eq_( + sa_typing.pep695_values(typing.Union[int, str]), + {typing.Union[int, str]}, + ) + + for t in ( + union_types() + null_union_types() + new_types() + annotated_l() + ): + eq_(sa_typing.pep695_values(t), {t}) + + eq_( + sa_typing.pep695_values(typing.Union[int, TA_int]), + {typing.Union[int, TA_int]}, + ) + + eq_(sa_typing.pep695_values(TA_int), {int}) + eq_(sa_typing.pep695_values(TA_union), {int, str}) + eq_(sa_typing.pep695_values(TA_null_union), {int, str, None}) + eq_(sa_typing.pep695_values(TA_null_union2), {int, str, None}) + eq_( + sa_typing.pep695_values(TA_null_union3), + {int, typing.ForwardRef("typing.Union[None, bool]")}, + ) + eq_( + sa_typing.pep695_values(TA_null_union4), + {int, typing.ForwardRef("TA_null_union2")}, + ) + eq_(sa_typing.pep695_values(TA_union_ta), {int, str}) + eq_(sa_typing.pep695_values(TA_null_union_ta), {int, str, None, float}) + eq_( + sa_typing.pep695_values(TA_list), + {int, str, typing.List[typing.ForwardRef("TA_list")]}, + ) + eq_( + sa_typing.pep695_values(TA_recursive), + {typing.ForwardRef("TA_recursive"), str}, + ) + eq_( + sa_typing.pep695_values(TA_null_recursive), + {typing.ForwardRef("TA_recursive"), str, None}, + ) + eq_( + sa_typing.pep695_values(TA_recursive_a), + {typing.ForwardRef("TA_recursive_b"), int}, + ) + eq_( + sa_typing.pep695_values(TA_recursive_b), + {typing.ForwardRef("TA_recursive_a"), str}, + ) + + def test_is_fwd_ref(self): + eq_(sa_typing.is_fwd_ref(int), False) + eq_(sa_typing.is_fwd_ref(make_fw_ref("str")), True) + eq_(sa_typing.is_fwd_ref(typing.Union[str, int]), False) + eq_(sa_typing.is_fwd_ref(typing.Union["str", int]), False) + eq_(sa_typing.is_fwd_ref(typing.Union["str", int], True), True) + + for t in all_types(): + eq_(sa_typing.is_fwd_ref(t), False) + + def test_de_optionalize_union_types(self): + fn = sa_typing.de_optionalize_union_types + + eq_( + fn(typing.Optional[typing.Union[int, str]]), typing.Union[int, str] + ) + eq_(fn(typing.Union[int, str, None]), typing.Union[int, str]) + eq_(fn(typing.Union[int, str, "None"]), typing.Union[int, str]) + + eq_(fn(make_fw_ref("None")), typing_extensions.Never) + eq_(fn(make_fw_ref("typing.Union[None]")), typing_extensions.Never) + eq_(fn(make_fw_ref("Union[None, str]")), typing.ForwardRef("str")) + eq_( + fn(make_fw_ref("Union[None, str, int]")), + typing.Union["str", "int"], + ) + eq_(fn(make_fw_ref("Optional[int]")), typing.ForwardRef("int")) + eq_( + fn(make_fw_ref("typing.Optional[Union[int | str]]")), + typing.ForwardRef("Union[int | str]"), + ) + + for t in null_union_types(): + res = fn(t) + eq_(sa_typing.is_union(res), True) + eq_(type(None) not in res.__args__, True) + + for t in union_types() + type_aliases() + new_types() + annotated_l(): + eq_(fn(t), t) + + eq_( + fn(make_fw_ref("Union[typing.Dict[str, int], int, None]")), + typing.Union["typing.Dict[str, int]", "int"], + ) + + def test_make_union_type(self): + eq_(sa_typing.make_union_type(int), int) + eq_(sa_typing.make_union_type(None), type(None)) + eq_(sa_typing.make_union_type(int, str), typing.Union[int, str]) + eq_( + sa_typing.make_union_type(int, typing.Optional[str]), + typing.Union[int, str, None], + ) + eq_( + sa_typing.make_union_type(int, typing.Union[str, bool]), + typing.Union[int, str, bool], + ) + eq_( + sa_typing.make_union_type(bool, TA_int, NT_str), + typing.Union[bool, TA_int, NT_str], + ) + + def test_includes_none(self): + eq_(sa_typing.includes_none(None), True) + eq_(sa_typing.includes_none(type(None)), True) + eq_(sa_typing.includes_none(typing.ForwardRef("None")), True) + eq_(sa_typing.includes_none(int), False) + for t in union_types(): + eq_(sa_typing.includes_none(t), False) + + for t in null_union_types(): + eq_(sa_typing.includes_none(t), True, str(t)) + + # TODO: these are false negatives + false_negative = { + TA_null_union4, # does not evaluate FW ref + } + for t in type_aliases() + new_types(): + if t in false_negative: + exp = False + else: + exp = "null" in t.__name__ + eq_(sa_typing.includes_none(t), exp, str(t)) + + for t in annotated_l(): + eq_( + sa_typing.includes_none(t), + "null" in sa_typing.get_args(t), + str(t), + ) + # nested things + eq_(sa_typing.includes_none(typing.Union[int, "None"]), True) + eq_(sa_typing.includes_none(typing.Union[bool, TA_null_union]), True) + eq_(sa_typing.includes_none(typing.Union[bool, NT_null]), True) + # nested fw + eq_( + sa_typing.includes_none( + typing.Union[int, "typing.Union[str, None]"] + ), + True, + ) + eq_( + sa_typing.includes_none( + typing.Union[int, "typing.Union[int, str]"] + ), + False, + ) + + # there are not supported. should return True + eq_( + sa_typing.includes_none(typing.Union[bool, "TA_null_union"]), False + ) + eq_(sa_typing.includes_none(typing.Union[bool, "NT_null"]), False) + + def test_is_union(self): + eq_(sa_typing.is_union(str), False) + for t in union_types() + null_union_types(): + eq_(sa_typing.is_union(t), True) + for t in type_aliases() + new_types() + annotated_l(): + eq_(sa_typing.is_union(t), False) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 85c419e94e..de8712c852 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -4,9 +4,6 @@ import inspect from pathlib import Path import pickle import sys -import typing - -import typing_extensions from sqlalchemy import exc from sqlalchemy import sql @@ -42,7 +39,6 @@ from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering from sqlalchemy.util._has_cy import _import_cy_extensions from sqlalchemy.util._has_cy import HAS_CYEXTENSION -from sqlalchemy.util.typing import is_union class WeakSequenceTest(fixtures.TestBase): @@ -3634,11 +3630,3 @@ class CyExtensionTest(fixtures.TestBase): for f in cython_files } eq_({m.__name__ for m in ext}, set(names)) - - -class TypingTest(fixtures.TestBase): - def test_is_union(self): - assert is_union(typing.Union[str, int]) - assert is_union(typing_extensions.Union[str, int]) - if compat.py310: - assert is_union(str | int) diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index c34d54169e..165f43b42d 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,8 +1,8 @@ """This file includes annotation-sensitive tests while having ``from __future__ import annotations`` in effect. -Only tests that don't have an equivalent in ``test_typed_mappings`` are -specified here. All test from ``test_typed_mappings`` are copied over to +Only tests that don't have an equivalent in ``test_typed_mapping`` are +specified here. All test from ``test_typed_mapping`` are copied over to the ``test_tm_future_annotations_sync`` by the ``sync_test_file`` script. """ diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index a2eac4d7f4..4b37926638 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -96,8 +96,9 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true -from sqlalchemy.testing import skip_test +from sqlalchemy.testing import requires from sqlalchemy.testing import Variation +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated @@ -118,11 +119,6 @@ _StrTypeAlias: TypeAlias = str _StrPep695: TypeAlias = str _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] -_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"] -_Recursive695_0: TypeAlias = _Literal695 -_Recursive695_1: TypeAlias = _Recursive695_0 -_Recursive695_2: TypeAlias = _Recursive695_1 - if compat.py38: _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] @@ -147,16 +143,16 @@ type _UnionPep695 = _SomeDict1 | _SomeDict2 type _StrPep695 = str type strtypalias_keyword = Annotated[str, mapped_column(info={"hi": "there"})] - -strtypalias_tat: typing.TypeAliasType = Annotated[ +type strtypalias_keyword_nested = int | Annotated[ + str, mapped_column(info={"hi": "there"})] +strtypalias_ta: typing.TypeAlias = Annotated[ str, mapped_column(info={"hi": "there"})] - strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] type _Literal695 = Literal["to-do", "in-progress", "done"] -type _Recursive695_0 = _Literal695 -type _Recursive695_1 = _Recursive695_0 -type _Recursive695_2 = _Recursive695_1 +type _RecursiveLiteral695 = _Literal695 + +type _JsonPep695 = _JsonPep604 """, globals(), ) @@ -856,6 +852,84 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(Test.__table__.c.data.type.length, 30) is_(Test.__table__.c.structure.type._type_affinity, JSON) + @testing.variation( + "option", + [ + "plain", + "union", + "union_604", + "union_null", + "union_null_604", + "optional", + "optional_union", + "optional_union_604", + ], + ) + @testing.variation("in_map", ["yes", "no", "value"]) + @testing.requires.python312 + def test_pep695_behavior(self, decl_base, in_map, option): + """Issue #11955""" + global tat + + if option.plain: + tat = TypeAliasType("tat", str) + elif option.union: + tat = TypeAliasType("tat", Union[str, int]) + elif option.union_604: + tat = TypeAliasType("tat", str | int) + elif option.union_null: + tat = TypeAliasType("tat", Union[str, int, None]) + elif option.union_null_604: + tat = TypeAliasType("tat", str | int | None) + elif option.optional: + tat = TypeAliasType("tat", Optional[str]) + elif option.optional_union: + tat = TypeAliasType("tat", Optional[Union[str, int]]) + elif option.optional_union_604: + tat = TypeAliasType("tat", Optional[str | int]) + else: + option.fail() + + if in_map.yes: + decl_base.registry.update_type_annotation_map({tat: String(99)}) + elif in_map.value: + decl_base.registry.update_type_annotation_map( + {tat.__value__: String(99)} + ) + + def declare(): + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[tat] + + return Test.__table__.c.data + + if in_map.yes: + col = declare() + length = 99 + elif in_map.value or option.optional or option.plain: + with expect_deprecated( + "Matching the provided TypeAliasType 'tat' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + col = declare() + length = 99 if in_map.value else None + else: + with expect_raises_message( + exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type", + ): + declare() + return + + is_true(isinstance(col.type, String)) + eq_(col.type.length, length) + nullable = "null" in option.name or "optional" in option.name + eq_(col.nullable, nullable) + @testing.requires.python312 def test_pep695_typealias_as_typemap_keys( self, decl_base: Type[DeclarativeBase] @@ -876,12 +950,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(Test.__table__.c.data.type.length, 30) is_(Test.__table__.c.structure.type._type_affinity, JSON) - @testing.variation("alias_type", ["none", "typekeyword", "typealiastype"]) + @testing.variation( + "alias_type", + ["none", "typekeyword", "typealias", "typekeyword_nested"], + ) @testing.requires.python312 def test_extract_pep593_from_pep695( self, decl_base: Type[DeclarativeBase], alias_type ): """test #11130""" + if alias_type.typekeyword: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword: VARCHAR(33)} # noqa: F821 + ) + if alias_type.typekeyword_nested: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword_nested: VARCHAR(42)} # noqa: F821 + ) class MyClass(decl_base): __tablename__ = "my_table" @@ -890,33 +975,96 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if alias_type.typekeyword: data_one: Mapped[strtypalias_keyword] # noqa: F821 - elif alias_type.typealiastype: - data_one: Mapped[strtypalias_tat] # noqa: F821 + elif alias_type.typealias: + data_one: Mapped[strtypalias_ta] # noqa: F821 elif alias_type.none: data_one: Mapped[strtypalias_plain] # noqa: F821 + elif alias_type.typekeyword_nested: + data_one: Mapped[strtypalias_keyword_nested] # noqa: F821 else: alias_type.fail() table = MyClass.__table__ assert table is not None - eq_(MyClass.data_one.expression.info, {"hi": "there"}) + if alias_type.typekeyword_nested: + # a nested annotation is not supported + eq_(MyClass.data_one.expression.info, {}) + else: + eq_(MyClass.data_one.expression.info, {"hi": "there"}) + if alias_type.typekeyword: + eq_(MyClass.data_one.type.length, 33) + elif alias_type.typekeyword_nested: + eq_(MyClass.data_one.type.length, 42) + else: + eq_(MyClass.data_one.type.length, None) + + @testing.variation("type_", ["literal", "recursive", "not_literal"]) + @testing.combinations(True, False, argnames="in_map") @testing.requires.python312 - def test_pep695_literal_defaults_to_enum(self, decl_base): + def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): """test #11305.""" - class Foo(decl_base): - __tablename__ = "footable" + def declare(): + class Foo(decl_base): + __tablename__ = "footable" - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[_Literal695] - r2: Mapped[_Recursive695_2] + id: Mapped[int] = mapped_column(primary_key=True) + if type_.recursive: + status: Mapped[_RecursiveLiteral695] # noqa: F821 + elif type_.literal: + status: Mapped[_Literal695] # noqa: F821 + elif type_.not_literal: + status: Mapped[_StrPep695] # noqa: F821 + else: + type_.fail() + + return Foo - for col in (Foo.__table__.c.status, Foo.__table__.c.r2): + if in_map: + decl_base.registry.update_type_annotation_map( + { + _Literal695: Enum(enum.Enum), # noqa: F821 + _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 + _StrPep695: Enum(enum.Enum), # noqa: F821 + } + ) + if type_.recursive: + with expect_deprecated( + "Mapping recursive TypeAliasType '.+' that resolve to " + "literal to generate an Enum is deprecated. SQLAlchemy " + "2.1 will not support this use case. Please avoid using " + "recursing TypeAliasType", + ): + Foo = declare() + elif type_.literal: + Foo = declare() + else: + with expect_raises_message( + exc.ArgumentError, + "Can't associate TypeAliasType '.+' to an Enum " + "since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal.'a', " + "'b'.` are supported when generating Enums.", + ): + declare() + return + else: + with expect_deprecated( + "Matching the provided TypeAliasType '.*' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + Foo = declare() + col = Foo.__table__.c.status + if in_map and not type_.not_literal: is_true(isinstance(col.type, Enum)) eq_(col.type.enums, ["to-do", "in-progress", "done"]) is_(col.type.native_enum, False) + else: + is_true(isinstance(col.type, String)) @testing.requires.python38 def test_typing_literal_identity(self, decl_base): @@ -1233,6 +1381,33 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(MyClass.__table__.c.data_four.type.length, 150) is_false(MyClass.__table__.c.data_four.nullable) + def test_newtype_missing_from_map(self, decl_base): + global str50 + + str50 = NewType("str50", str) + + if compat.py310: + text = ".*str50" + else: + # NewTypes before 3.10 had a very bad repr + # .new_type at 0x...> + text = ".*NewType.*" + + with expect_deprecated( + f"Matching the provided NewType '{text}' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data_one: Mapped[str50] + + is_true(isinstance(MyClass.data_one.type, String)) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -1724,39 +1899,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - @testing.variation("union", ["union", "pep604"]) - @testing.variation("typealias", ["legacy", "pep695"]) - def test_unions(self, union, typealias): + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + "union_null", + ("pep604_null", requires.python310), + ], + ) + def test_unions(self, union): + global UnionType our_type = Numeric(10, 2) if union.union: UnionType = Union[float, Decimal] + elif union.union_null: + UnionType = Union[float, Decimal, None] elif union.pep604: - if not compat.py310: - skip_test("Required Python 3.10") UnionType = float | Decimal + elif union.pep604_null: + UnionType = float | Decimal | None else: union.fail() - if typealias.legacy: - UnionTypeAlias = UnionType - elif typealias.pep695: - # same as type UnionTypeAlias = UnionType - UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) - else: - typealias.fail() - class Base(DeclarativeBase): - type_annotation_map = {UnionTypeAlias: our_type} + type_annotation_map = {UnionType: our_type} class User(Base): __tablename__ = "users" - __table__: Table id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[Union[float, Decimal]] = mapped_column() - reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + data: Mapped[Union[float, Decimal]] + reverse_data: Mapped[Union[Decimal, float]] optional_data: Mapped[Optional[Union[float, Decimal]]] = ( mapped_column() @@ -1773,6 +1949,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + refer_union: Mapped[UnionType] + refer_union_optional: Mapped[Optional[UnionType]] + float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1788,65 +1967,54 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) - if compat.py312: - MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) - pep695_data: Mapped[MyTypeAlias] = mapped_column() - - is_(User.__table__.c.data.type, our_type) - is_false(User.__table__.c.data.nullable) - is_(User.__table__.c.reverse_data.type, our_type) - is_(User.__table__.c.optional_data.type, our_type) - is_true(User.__table__.c.optional_data.nullable) + info = [ + ("data", False), + ("reverse_data", False), + ("optional_data", True), + ("reverse_optional_data", True), + ("reverse_u_optional_data", True), + ("refer_union", "null" in union.name), + ("refer_union_optional", True), + ] + if compat.py310: + info += [ + ("pep604_data", False), + ("pep604_reverse", False), + ("pep604_optional", True), + ("pep604_data_fwd", False), + ("pep604_reverse_fwd", False), + ("pep604_optional_fwd", True), + ] - is_(User.__table__.c.reverse_optional_data.type, our_type) - is_(User.__table__.c.reverse_u_optional_data.type, our_type) - is_true(User.__table__.c.reverse_optional_data.nullable) - is_true(User.__table__.c.reverse_u_optional_data.nullable) + for name, nullable in info: + col = User.__table__.c[name] + is_(col.type, our_type, name) + is_(col.nullable, nullable, name) is_true(isinstance(User.__table__.c.float_data.type, Float)) + ne_(User.__table__.c.float_data.type, our_type) - is_not(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.decimal_data.type, Numeric)) + ne_(User.__table__.c.decimal_data.type, our_type) - if compat.py310: - for suffix in ("", "_fwd"): - data_col = User.__table__.c[f"pep604_data{suffix}"] - reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] - optional_col = User.__table__.c[f"pep604_optional{suffix}"] - is_(data_col.type, our_type) - is_false(data_col.nullable) - is_(reverse_col.type, our_type) - is_false(reverse_col.nullable) - is_(optional_col.type, our_type) - is_true(optional_col.nullable) - - if compat.py312: - is_(User.__table__.c.pep695_data.type, our_type) - - @testing.variation("union", ["union", "pep604"]) + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + ("pep695", requires.python312), + ], + ) def test_optional_in_annotation_map(self, union): - """SQLAlchemy's behaviour is clear: an optional type means the column - is inferred as nullable. Some types which a user may want to put in the - type annotation map are already optional. JSON is a good example - because without any constraint, the type can be None via JSON null or - SQL NULL. - - By permitting optional types in the type annotation map, everything - just works, and mapped_column(nullable=False) is available if desired. - - See issue #11370 - """ + """See issue #11370""" class Base(DeclarativeBase): if union.union: - type_annotation_map = { - _Json: JSON, - } + type_annotation_map = {_Json: JSON} elif union.pep604: - if not compat.py310: - skip_test("Requires Python 3.10+") - type_annotation_map = { - _JsonPep604: JSON, - } + type_annotation_map = {_JsonPep604: JSON} + elif union.pep695: + type_annotation_map = {_JsonPep695: JSON} # noqa: F821 else: union.fail() @@ -1858,10 +2026,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): json1: Mapped[_Json] json2: Mapped[_Json] = mapped_column(nullable=False) elif union.pep604: - if not compat.py310: - skip_test("Requires Python 3.10+") json1: Mapped[_JsonPep604] json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + elif union.pep695: + json1: Mapped[_JsonPep695] # noqa: F821 + json2: Mapped[_JsonPep695] = mapped_column( # noqa: F821 + nullable=False + ) else: union.fail() diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 5026e676a7..f1970f2183 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -87,8 +87,9 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true -from sqlalchemy.testing import skip_test +from sqlalchemy.testing import requires from sqlalchemy.testing import Variation +from sqlalchemy.testing.assertions import ne_ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated @@ -109,11 +110,6 @@ _StrTypeAlias: TypeAlias = str _StrPep695: TypeAlias = str _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] -_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"] -_Recursive695_0: TypeAlias = _Literal695 -_Recursive695_1: TypeAlias = _Recursive695_0 -_Recursive695_2: TypeAlias = _Recursive695_1 - if compat.py38: _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] @@ -138,16 +134,16 @@ type _UnionPep695 = _SomeDict1 | _SomeDict2 type _StrPep695 = str type strtypalias_keyword = Annotated[str, mapped_column(info={"hi": "there"})] - -strtypalias_tat: typing.TypeAliasType = Annotated[ +type strtypalias_keyword_nested = int | Annotated[ + str, mapped_column(info={"hi": "there"})] +strtypalias_ta: typing.TypeAlias = Annotated[ str, mapped_column(info={"hi": "there"})] - strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})] type _Literal695 = Literal["to-do", "in-progress", "done"] -type _Recursive695_0 = _Literal695 -type _Recursive695_1 = _Recursive695_0 -type _Recursive695_2 = _Recursive695_1 +type _RecursiveLiteral695 = _Literal695 + +type _JsonPep695 = _JsonPep604 """, globals(), ) @@ -847,6 +843,84 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(Test.__table__.c.data.type.length, 30) is_(Test.__table__.c.structure.type._type_affinity, JSON) + @testing.variation( + "option", + [ + "plain", + "union", + "union_604", + "union_null", + "union_null_604", + "optional", + "optional_union", + "optional_union_604", + ], + ) + @testing.variation("in_map", ["yes", "no", "value"]) + @testing.requires.python312 + def test_pep695_behavior(self, decl_base, in_map, option): + """Issue #11955""" + # anno only: global tat + + if option.plain: + tat = TypeAliasType("tat", str) + elif option.union: + tat = TypeAliasType("tat", Union[str, int]) + elif option.union_604: + tat = TypeAliasType("tat", str | int) + elif option.union_null: + tat = TypeAliasType("tat", Union[str, int, None]) + elif option.union_null_604: + tat = TypeAliasType("tat", str | int | None) + elif option.optional: + tat = TypeAliasType("tat", Optional[str]) + elif option.optional_union: + tat = TypeAliasType("tat", Optional[Union[str, int]]) + elif option.optional_union_604: + tat = TypeAliasType("tat", Optional[str | int]) + else: + option.fail() + + if in_map.yes: + decl_base.registry.update_type_annotation_map({tat: String(99)}) + elif in_map.value: + decl_base.registry.update_type_annotation_map( + {tat.__value__: String(99)} + ) + + def declare(): + class Test(decl_base): + __tablename__ = "test" + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[tat] + + return Test.__table__.c.data + + if in_map.yes: + col = declare() + length = 99 + elif in_map.value or option.optional or option.plain: + with expect_deprecated( + "Matching the provided TypeAliasType 'tat' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + col = declare() + length = 99 if in_map.value else None + else: + with expect_raises_message( + exc.ArgumentError, + "Could not locate SQLAlchemy Core type for Python type", + ): + declare() + return + + is_true(isinstance(col.type, String)) + eq_(col.type.length, length) + nullable = "null" in option.name or "optional" in option.name + eq_(col.nullable, nullable) + @testing.requires.python312 def test_pep695_typealias_as_typemap_keys( self, decl_base: Type[DeclarativeBase] @@ -867,12 +941,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(Test.__table__.c.data.type.length, 30) is_(Test.__table__.c.structure.type._type_affinity, JSON) - @testing.variation("alias_type", ["none", "typekeyword", "typealiastype"]) + @testing.variation( + "alias_type", + ["none", "typekeyword", "typealias", "typekeyword_nested"], + ) @testing.requires.python312 def test_extract_pep593_from_pep695( self, decl_base: Type[DeclarativeBase], alias_type ): """test #11130""" + if alias_type.typekeyword: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword: VARCHAR(33)} # noqa: F821 + ) + if alias_type.typekeyword_nested: + decl_base.registry.update_type_annotation_map( + {strtypalias_keyword_nested: VARCHAR(42)} # noqa: F821 + ) class MyClass(decl_base): __tablename__ = "my_table" @@ -881,33 +966,96 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if alias_type.typekeyword: data_one: Mapped[strtypalias_keyword] # noqa: F821 - elif alias_type.typealiastype: - data_one: Mapped[strtypalias_tat] # noqa: F821 + elif alias_type.typealias: + data_one: Mapped[strtypalias_ta] # noqa: F821 elif alias_type.none: data_one: Mapped[strtypalias_plain] # noqa: F821 + elif alias_type.typekeyword_nested: + data_one: Mapped[strtypalias_keyword_nested] # noqa: F821 else: alias_type.fail() table = MyClass.__table__ assert table is not None - eq_(MyClass.data_one.expression.info, {"hi": "there"}) + if alias_type.typekeyword_nested: + # a nested annotation is not supported + eq_(MyClass.data_one.expression.info, {}) + else: + eq_(MyClass.data_one.expression.info, {"hi": "there"}) + if alias_type.typekeyword: + eq_(MyClass.data_one.type.length, 33) + elif alias_type.typekeyword_nested: + eq_(MyClass.data_one.type.length, 42) + else: + eq_(MyClass.data_one.type.length, None) + + @testing.variation("type_", ["literal", "recursive", "not_literal"]) + @testing.combinations(True, False, argnames="in_map") @testing.requires.python312 - def test_pep695_literal_defaults_to_enum(self, decl_base): + def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map): """test #11305.""" - class Foo(decl_base): - __tablename__ = "footable" + def declare(): + class Foo(decl_base): + __tablename__ = "footable" - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[_Literal695] - r2: Mapped[_Recursive695_2] + id: Mapped[int] = mapped_column(primary_key=True) + if type_.recursive: + status: Mapped[_RecursiveLiteral695] # noqa: F821 + elif type_.literal: + status: Mapped[_Literal695] # noqa: F821 + elif type_.not_literal: + status: Mapped[_StrPep695] # noqa: F821 + else: + type_.fail() + + return Foo - for col in (Foo.__table__.c.status, Foo.__table__.c.r2): + if in_map: + decl_base.registry.update_type_annotation_map( + { + _Literal695: Enum(enum.Enum), # noqa: F821 + _RecursiveLiteral695: Enum(enum.Enum), # noqa: F821 + _StrPep695: Enum(enum.Enum), # noqa: F821 + } + ) + if type_.recursive: + with expect_deprecated( + "Mapping recursive TypeAliasType '.+' that resolve to " + "literal to generate an Enum is deprecated. SQLAlchemy " + "2.1 will not support this use case. Please avoid using " + "recursing TypeAliasType", + ): + Foo = declare() + elif type_.literal: + Foo = declare() + else: + with expect_raises_message( + exc.ArgumentError, + "Can't associate TypeAliasType '.+' to an Enum " + "since it's not a direct alias of a Literal. Only " + "aliases in this form `type my_alias = Literal.'a', " + "'b'.` are supported when generating Enums.", + ): + declare() + return + else: + with expect_deprecated( + "Matching the provided TypeAliasType '.*' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + Foo = declare() + col = Foo.__table__.c.status + if in_map and not type_.not_literal: is_true(isinstance(col.type, Enum)) eq_(col.type.enums, ["to-do", "in-progress", "done"]) is_(col.type.native_enum, False) + else: + is_true(isinstance(col.type, String)) @testing.requires.python38 def test_typing_literal_identity(self, decl_base): @@ -1224,6 +1372,33 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): eq_(MyClass.__table__.c.data_four.type.length, 150) is_false(MyClass.__table__.c.data_four.nullable) + def test_newtype_missing_from_map(self, decl_base): + # anno only: global str50 + + str50 = NewType("str50", str) + + if compat.py310: + text = ".*str50" + else: + # NewTypes before 3.10 had a very bad repr + # .new_type at 0x...> + text = ".*NewType.*" + + with expect_deprecated( + f"Matching the provided NewType '{text}' on its " + "resolved value without matching it in the " + "type_annotation_map is deprecated; add this type to the " + "type_annotation_map to allow it to match explicitly.", + ): + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + data_one: Mapped[str50] + + is_true(isinstance(MyClass.data_one.type, String)) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -1715,39 +1890,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - @testing.variation("union", ["union", "pep604"]) - @testing.variation("typealias", ["legacy", "pep695"]) - def test_unions(self, union, typealias): + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + "union_null", + ("pep604_null", requires.python310), + ], + ) + def test_unions(self, union): + # anno only: global UnionType our_type = Numeric(10, 2) if union.union: UnionType = Union[float, Decimal] + elif union.union_null: + UnionType = Union[float, Decimal, None] elif union.pep604: - if not compat.py310: - skip_test("Required Python 3.10") UnionType = float | Decimal + elif union.pep604_null: + UnionType = float | Decimal | None else: union.fail() - if typealias.legacy: - UnionTypeAlias = UnionType - elif typealias.pep695: - # same as type UnionTypeAlias = UnionType - UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) - else: - typealias.fail() - class Base(DeclarativeBase): - type_annotation_map = {UnionTypeAlias: our_type} + type_annotation_map = {UnionType: our_type} class User(Base): __tablename__ = "users" - __table__: Table id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[Union[float, Decimal]] = mapped_column() - reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + data: Mapped[Union[float, Decimal]] + reverse_data: Mapped[Union[Decimal, float]] optional_data: Mapped[Optional[Union[float, Decimal]]] = ( mapped_column() @@ -1764,6 +1940,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + refer_union: Mapped[UnionType] + refer_union_optional: Mapped[Optional[UnionType]] + float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1779,65 +1958,54 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) - if compat.py312: - MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) - pep695_data: Mapped[MyTypeAlias] = mapped_column() - - is_(User.__table__.c.data.type, our_type) - is_false(User.__table__.c.data.nullable) - is_(User.__table__.c.reverse_data.type, our_type) - is_(User.__table__.c.optional_data.type, our_type) - is_true(User.__table__.c.optional_data.nullable) + info = [ + ("data", False), + ("reverse_data", False), + ("optional_data", True), + ("reverse_optional_data", True), + ("reverse_u_optional_data", True), + ("refer_union", "null" in union.name), + ("refer_union_optional", True), + ] + if compat.py310: + info += [ + ("pep604_data", False), + ("pep604_reverse", False), + ("pep604_optional", True), + ("pep604_data_fwd", False), + ("pep604_reverse_fwd", False), + ("pep604_optional_fwd", True), + ] - is_(User.__table__.c.reverse_optional_data.type, our_type) - is_(User.__table__.c.reverse_u_optional_data.type, our_type) - is_true(User.__table__.c.reverse_optional_data.nullable) - is_true(User.__table__.c.reverse_u_optional_data.nullable) + for name, nullable in info: + col = User.__table__.c[name] + is_(col.type, our_type, name) + is_(col.nullable, nullable, name) is_true(isinstance(User.__table__.c.float_data.type, Float)) + ne_(User.__table__.c.float_data.type, our_type) - is_not(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.decimal_data.type, Numeric)) + ne_(User.__table__.c.decimal_data.type, our_type) - if compat.py310: - for suffix in ("", "_fwd"): - data_col = User.__table__.c[f"pep604_data{suffix}"] - reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] - optional_col = User.__table__.c[f"pep604_optional{suffix}"] - is_(data_col.type, our_type) - is_false(data_col.nullable) - is_(reverse_col.type, our_type) - is_false(reverse_col.nullable) - is_(optional_col.type, our_type) - is_true(optional_col.nullable) - - if compat.py312: - is_(User.__table__.c.pep695_data.type, our_type) - - @testing.variation("union", ["union", "pep604"]) + @testing.variation( + "union", + [ + "union", + ("pep604", requires.python310), + ("pep695", requires.python312), + ], + ) def test_optional_in_annotation_map(self, union): - """SQLAlchemy's behaviour is clear: an optional type means the column - is inferred as nullable. Some types which a user may want to put in the - type annotation map are already optional. JSON is a good example - because without any constraint, the type can be None via JSON null or - SQL NULL. - - By permitting optional types in the type annotation map, everything - just works, and mapped_column(nullable=False) is available if desired. - - See issue #11370 - """ + """See issue #11370""" class Base(DeclarativeBase): if union.union: - type_annotation_map = { - _Json: JSON, - } + type_annotation_map = {_Json: JSON} elif union.pep604: - if not compat.py310: - skip_test("Requires Python 3.10+") - type_annotation_map = { - _JsonPep604: JSON, - } + type_annotation_map = {_JsonPep604: JSON} + elif union.pep695: + type_annotation_map = {_JsonPep695: JSON} # noqa: F821 else: union.fail() @@ -1849,10 +2017,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): json1: Mapped[_Json] json2: Mapped[_Json] = mapped_column(nullable=False) elif union.pep604: - if not compat.py310: - skip_test("Requires Python 3.10+") json1: Mapped[_JsonPep604] json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + elif union.pep695: + json1: Mapped[_JsonPep695] # noqa: F821 + json2: Mapped[_JsonPep695] = mapped_column( # noqa: F821 + nullable=False + ) else: union.fail() diff --git a/tools/format_docs_code.py b/tools/format_docs_code.py index 3a06ac9f27..a3b6965c86 100644 --- a/tools/format_docs_code.py +++ b/tools/format_docs_code.py @@ -12,6 +12,7 @@ that it extracts from the documentation. from argparse import ArgumentParser from argparse import RawDescriptionHelpFormatter from collections.abc import Iterator +import dataclasses from functools import partial from itertools import chain from pathlib import Path @@ -33,6 +34,8 @@ ignore_paths = ( re.compile(r"build"), ) +CUSTOM_TARGET_VERSIONS = {"declarative_tables.rst": "PY312"} + class BlockLine(NamedTuple): line: str @@ -66,6 +69,12 @@ def _format_block( code = "\n".join(l.code for l in input_block) mode = PYTHON_BLACK_MODE if is_python_file else RST_BLACK_MODE + custom_target = CUSTOM_TARGET_VERSIONS.get(Path(file).name) + if custom_target: + mode = dataclasses.replace( + mode, target_versions={TargetVersion[custom_target]} + ) + try: formatted = format_str(code, mode=mode) except Exception as e: