General improvement on annotated declarative
author     Federico Caselli <cfederico87@gmail.com>
           Tue, 19 Nov 2024 22:12:51 +0000 (23:12 +0100)
committer  Federico Caselli <cfederico87@gmail.com>
           Thu, 12 Dec 2024 21:10:06 +0000 (22:10 +0100)
Fix issue that resulted in inconsistent handling of unions
depending on how they were declared

Consistently support TypeAliasType. This required a revision of
the implementation added in #11305 in order to achieve consistent
behavior.

References: #11944
References: #11955
References: #11305
Change-Id: Iffc34fd42b9769f73ddb4331bd59b6b37391635d
(cherry picked from commit e6b0b421d60ecf660cf3872f3f32dd2b7a739b59)

16 files changed:
doc/build/changelog/unreleased_20/11944.rst [new file with mode: 0644]
doc/build/changelog/unreleased_20/11955.rst [new file with mode: 0644]
doc/build/orm/declarative_tables.rst
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/util/typing.py
test/base/test_typing_utils.py [new file with mode: 0644]
test/base/test_utils.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py
tools/format_docs_code.py

diff --git a/doc/build/changelog/unreleased_20/11944.rst b/doc/build/changelog/unreleased_20/11944.rst
new file mode 100644 (file)
index 0000000..e746918
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11944
+
+    Fixed bug in how type unions were handled that made the behavior
+    of ``a | b`` different from ``Union[a, b]``.
diff --git a/doc/build/changelog/unreleased_20/11955.rst b/doc/build/changelog/unreleased_20/11955.rst
new file mode 100644 (file)
index 0000000..eeeb2bc
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11955
+
+    Consistently handle ``TypeAliasType`` (defined in PEP 695) obtained with the
+    ``type X = int`` syntax introduced in Python 3.12.
+    Such an alias must now in all cases be explicitly added to the type map for
+    it to be usable inside ``Mapped[]``.
+    This change also revises the approach added in :ticket:`11305`, now requiring
+    the ``TypeAliasType`` to be added to the type map.
+    Documentation on how unions and type alias types are handled by SQLAlchemy
+    has been added in the :ref:`orm_declarative_mapped_column_type_map` section
+    of the documentation.
index b2c91981b3ecc0d84013461a74f46df040dc32e9..4bb4237ac1757088af352b781b2c48fbcf10aed5 100644 (file)
@@ -316,9 +316,8 @@ the registry and Declarative base could be configured as::
 
     import datetime
 
-    from sqlalchemy import BIGINT, Integer, NVARCHAR, String, TIMESTAMP
-    from sqlalchemy.orm import DeclarativeBase
-    from sqlalchemy.orm import Mapped, mapped_column, registry
+    from sqlalchemy import BIGINT, NVARCHAR, String, TIMESTAMP
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
 
     class Base(DeclarativeBase):
@@ -369,6 +368,59 @@ while still being able to use succinct annotation-only :func:`_orm.mapped_column
 configurations.  There are two more levels of Python-type configurability
 available beyond this, described in the next two sections.
 
+Union types inside the Type Map
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+SQLAlchemy supports mapping union types inside the type map to allow
+mapping database types that can support multiple Python types,
+such as :class:`_types.JSON` or :class:`_postgresql.JSONB`::
+
+    from sqlalchemy import JSON
+    from sqlalchemy.dialects import postgresql
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+    from sqlalchemy.schema import CreateTable
+
+    json_list = list[int] | list[str]
+    json_scalar = float | str | bool | None
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {
+            json_list: postgresql.JSONB,
+            json_scalar: JSON,
+        }
+
+
+    class SomeClass(Base):
+        __tablename__ = "some_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        list_col: Mapped[list[str] | list[int]]
+        scalar_col: Mapped[json_scalar]
+        scalar_col_not_null: Mapped[str | float | bool]
+
+Using the union directly inside ``Mapped`` or creating a new union with the
+same effective types behaves the same way: ``list_col`` is matched to the
+``json_list`` union even though it does not reference it directly (the order
+of the types also does not matter).
+If a union in the type map includes ``None``, the ``None`` part is ignored
+when matching the ``Mapped`` annotation, since ``None`` only determines the
+column nullability. It follows that both ``scalar_col`` and
+``scalar_col_not_null`` match the ``json_scalar`` union.
+
+The CREATE TABLE statement for the table defined above renders as follows:
+
+.. sourcecode:: pycon+sql
+
+    >>> print(CreateTable(SomeClass.__table__).compile(dialect=postgresql.dialect()))
+    {printsql}CREATE TABLE some_table (
+        id SERIAL NOT NULL,
+        list_col JSONB NOT NULL,
+        scalar_col JSON,
+        scalar_col_not_null JSON NOT NULL,
+        PRIMARY KEY (id)
+    )
+
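
As an illustrative sketch of the matching behavior described above (not taken
verbatim from the patch; assumes SQLAlchemy 2.0 with this change applied and
Python 3.10+ for the ``|`` syntax), the union spelled in a different order and
with ``typing.Union`` still resolves against the same type map entries::

    from typing import Union

    from sqlalchemy import JSON
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        type_annotation_map = {
            list[int] | list[str]: postgresql.JSONB,
            float | str | bool | None: JSON,
        }


    class SomeClass(Base):
        __tablename__ = "some_table"

        id: Mapped[int] = mapped_column(primary_key=True)
        # reversed order, typing.Union spelling: still matches the JSONB entry
        list_col: Mapped[Union[list[str], list[int]]]
        # the JSON union minus ``None``: matches, rendered NOT NULL
        scalar_col_not_null: Mapped[str | float | bool]


    assert isinstance(SomeClass.__table__.c.list_col.type, postgresql.JSONB)
    assert isinstance(SomeClass.__table__.c.scalar_col_not_null.type, JSON)
    assert SomeClass.__table__.c.scalar_col_not_null.nullable is False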
 .. _orm_declarative_mapped_column_type_map_pep593:
 
 Mapping Multiple Type Configurations to Python Types
@@ -458,6 +510,96 @@ us a wide degree of flexibility, the next section illustrates a second
 way in which ``Annotated`` may be used with Declarative that is even
 more open ended.
 
+Support for Type Alias Types (defined by PEP 695) and NewType
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The ``typing`` module allows a user to create "new types" using ``typing.NewType``::
+
+    from typing import NewType
+
+    nstr30 = NewType("nstr30", str)
+    nstr50 = NewType("nstr50", str)
+
+These are treated as distinct types by type checkers and by Python itself::
+
+    >>> print(str == nstr30, nstr50 == nstr30, nstr30 == NewType("nstr30", str))
+    False False False
+
+A similar feature was added in Python 3.12: type aliases created with the new
+``type`` syntax, which produces ``typing.TypeAliasType`` objects::
+
+    type SmallInt = int
+    type BigInt = int
+    type JsonScalar = str | float | bool | None
+
+Like ``typing.NewType``, these are treated by Python as distinct types, meaning
+that they do not compare equal to each other even if they represent the same
+Python type. In the example above, ``SmallInt`` and ``BigInt`` are not
+considered equal even though both are aliases of the Python type ``int``::
+
+    >>> print(SmallInt == BigInt)
+    False
+
+SQLAlchemy supports using ``typing.NewType`` and ``typing.TypeAliasType``
+in the ``type_annotation_map``. They can be used to associate the same Python
+type with different :class:`_types.TypeEngine` types, similarly to
+``typing.Annotated``::
+
+    from sqlalchemy import SmallInteger, BigInteger, JSON, String
+    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+    from sqlalchemy.schema import CreateTable
+
+
+    class TABase(DeclarativeBase):
+        type_annotation_map = {
+            nstr30: String(30),
+            nstr50: String(50),
+            SmallInt: SmallInteger,
+            BigInt: BigInteger,
+            JsonScalar: JSON,
+        }
+
+
+    class SomeClass(TABase):
+        __tablename__ = "some_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        normal_str: Mapped[str]
+
+        short_str: Mapped[nstr30]
+        long_str: Mapped[nstr50]
+
+        small_int: Mapped[SmallInt]
+        big_int: Mapped[BigInt]
+        scalar_col: Mapped[JsonScalar]
+
+A CREATE TABLE for the above mapping illustrates the different variants of
+integer and string we've configured, and looks like:
+
+.. sourcecode:: pycon+sql
+
+    >>> print(CreateTable(SomeClass.__table__))
+    {printsql}CREATE TABLE some_table (
+        id INTEGER NOT NULL,
+        normal_str VARCHAR NOT NULL,
+        short_str VARCHAR(30) NOT NULL,
+        long_str VARCHAR(50) NOT NULL,
+        small_int SMALLINT NOT NULL,
+        big_int BIGINT NOT NULL,
+        scalar_col JSON,
+        PRIMARY KEY (id)
+    )
+
+Since the ``JsonScalar`` type includes ``None``, its column is nullable, while
+the ``id`` and ``normal_str`` columns use the default mapping for their
+respective Python types.
+
+As mentioned above, since ``typing.NewType`` and ``typing.TypeAliasType`` are
+considered standalone types, they must be referenced directly inside ``Mapped``
+and must be added explicitly to the type map.
+Failing to do so will raise an error since SQLAlchemy does not know what
+SQL type to use.
+
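
The tests in this patch also register such aliases after the fact using
``registry.update_type_annotation_map()``; a minimal sketch of that style
(assuming Python 3.12 for the ``type`` statement)::

    from sqlalchemy import String
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type nstr50 = str  # a typing.NewType would be registered the same way


    class Base(DeclarativeBase):
        pass


    # equivalent to listing the alias in type_annotation_map on the class
    Base.registry.update_type_annotation_map({nstr50: String(50)})


    class SomeClass(Base):
        __tablename__ = "some_table"

        id: Mapped[int] = mapped_column(primary_key=True)
        long_str: Mapped[nstr50]


    assert SomeClass.__table__.c.long_str.type.length == 50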
 .. _orm_declarative_mapped_column_pep593:
 
 Mapping Whole Column Declarations to Python Types
@@ -743,6 +885,28 @@ appropriate settings, including default string length.   If a ``typing.Literal``
 that does not consist of only string values is passed, an informative
 error is raised.
 
+A ``typing.TypeAliasType`` can also be used to create an enum, by aliasing a
+``typing.Literal`` of strings::
+
+    from typing import Literal
+
+    type Status = Literal["on", "off", "unknown"]
+
+Since this is a ``typing.TypeAliasType``, it represents a unique type object,
+so it must be placed in the ``type_annotation_map`` for it to be looked up
+successfully, keyed to the :class:`.Enum` type as follows::
+
+    import enum
+    import sqlalchemy
+
+
+    class Base(DeclarativeBase):
+        type_annotation_map = {Status: sqlalchemy.Enum(enum.Enum)}
+
+Since SQLAlchemy supports mapping ``typing.TypeAliasType`` objects individually,
+even when they are otherwise structurally equivalent, such aliases must be
+present in ``type_annotation_map`` to avoid ambiguity.
+
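
Putting the pieces together, a condensed sketch of the mapping described above
(the ``Device`` class is illustrative only; assumes Python 3.12)::

    import enum
    from typing import Literal

    import sqlalchemy
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type Status = Literal["on", "off", "unknown"]


    class Base(DeclarativeBase):
        type_annotation_map = {Status: sqlalchemy.Enum(enum.Enum)}


    class Device(Base):
        __tablename__ = "device"

        id: Mapped[int] = mapped_column(primary_key=True)
        status: Mapped[Status]


    # the Literal's string values become the Enum's values
    assert list(Device.__table__.c.status.type.enums) == ["on", "off", "unknown"]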
 Native Enums and Naming
 +++++++++++++++++++++++
 
index 718cf72516bd6773f4e3b7455886c1aa9fa9977a..a3b0ac21f0a20ecd0fea0472fbbee4464c3d066f 100644 (file)
@@ -14,7 +14,6 @@ import re
 import typing
 from typing import Any
 from typing import Callable
-from typing import cast
 from typing import ClassVar
 from typing import Dict
 from typing import FrozenSet
@@ -72,6 +71,7 @@ from ..sql.selectable import FromClause
 from ..util import hybridmethod
 from ..util import hybridproperty
 from ..util import typing as compat_typing
+from ..util import warn_deprecated
 from ..util.typing import CallableReference
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import flatten_newtype
@@ -80,6 +80,7 @@ from ..util.typing import is_literal
 from ..util.typing import is_newtype
 from ..util.typing import is_pep695
 from ..util.typing import Literal
+from ..util.typing import LITERAL_TYPES
 from ..util.typing import Self
 
 if TYPE_CHECKING:
@@ -1232,40 +1233,27 @@ class registry:
         )
 
     def _resolve_type(
-        self, python_type: _MatchedOnType
+        self, python_type: _MatchedOnType, _do_fallbacks: bool = True
     ) -> Optional[sqltypes.TypeEngine[Any]]:
-
-        python_type_to_check = python_type
-        while is_pep695(python_type_to_check):
-            python_type_to_check = python_type_to_check.__value__
-
-        check_is_pt = python_type is python_type_to_check
-
         python_type_type: Type[Any]
         search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
 
-        if is_generic(python_type_to_check):
-            if is_literal(python_type_to_check):
-                python_type_type = cast("Type[Any]", python_type_to_check)
+        if is_generic(python_type):
+            if is_literal(python_type):
+                python_type_type = python_type  # type: ignore[assignment]
 
-                search = (  # type: ignore[assignment]
+                search = (
                     (python_type, python_type_type),
-                    (Literal, python_type_type),
+                    *((lt, python_type_type) for lt in LITERAL_TYPES),  # type: ignore[arg-type] # noqa: E501
                 )
             else:
-                python_type_type = python_type_to_check.__origin__
+                python_type_type = python_type.__origin__
                 search = ((python_type, python_type_type),)
-        elif is_newtype(python_type_to_check):
-            python_type_type = flatten_newtype(python_type_to_check)
-            search = ((python_type, python_type_type),)
-        elif isinstance(python_type_to_check, type):
-            python_type_type = python_type_to_check
-            search = (
-                (pt if check_is_pt else python_type, pt)
-                for pt in python_type_type.__mro__
-            )
+        elif isinstance(python_type, type):
+            python_type_type = python_type
+            search = ((pt, pt) for pt in python_type_type.__mro__)
         else:
-            python_type_type = python_type_to_check  # type: ignore[assignment]
+            python_type_type = python_type  # type: ignore[assignment]
             search = ((python_type, python_type_type),)
 
         for pt, flattened in search:
@@ -1290,6 +1278,39 @@ class registry:
                 if resolved_sql_type is not None:
                     return resolved_sql_type
 
+        # 2.0 fallbacks
+        if _do_fallbacks:
+            python_type_to_check: Any = None
+            kind = None
+            if is_pep695(python_type):
+                # NOTE: assume there aren't type alias types of new types.
+                python_type_to_check = python_type
+                while is_pep695(python_type_to_check):
+                    python_type_to_check = python_type_to_check.__value__
+                python_type_to_check = de_optionalize_union_types(
+                    python_type_to_check
+                )
+                kind = "TypeAliasType"
+            if is_newtype(python_type):
+                python_type_to_check = flatten_newtype(python_type)
+                kind = "NewType"
+
+            if python_type_to_check is not None:
+                res_after_fallback = self._resolve_type(
+                    python_type_to_check, False
+                )
+                if res_after_fallback is not None:
+                    assert kind is not None
+                    warn_deprecated(
+                        f"Matching the provided {kind} '{python_type}' on "
+                        "its resolved value without matching it in the "
+                        "type_annotation_map is deprecated; add this type to "
+                        "the type_annotation_map to allow it to match "
+                        "explicitly.",
+                        "2.0",
+                    )
+                    return res_after_fallback
+
         return None
 
     @property
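
For context, a rough sketch of the user-visible effect of the "2.0 fallbacks"
branch above (the ``MyInt`` name is illustrative; assumes Python 3.12): an
alias that is absent from the type map, but whose underlying value resolves,
still maps while emitting a deprecation warning rather than failing outright:

    import warnings

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type MyInt = int  # deliberately *not* added to any type_annotation_map


    class Base(DeclarativeBase):
        pass


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        class SomeClass(Base):
            __tablename__ = "some_table"

            id: Mapped[int] = mapped_column(primary_key=True)
            value: Mapped[MyInt]  # falls back to the plain ``int`` lookup


    # the warning asks for MyInt to be added to the type_annotation_map
    assert any("type_annotation_map" in str(w.message) for w in caught)
    assert str(SomeClass.__table__.c.value.type) == "INTEGER"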
index b069d23c0f51209f48b8ede4fc3cbbd5f8fe8881..aa64eaa666780866ea611a022a2f75d9bb3acdd1 100644 (file)
@@ -65,11 +65,11 @@ from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util import topological
 from ..util.typing import _AnnotationScanType
+from ..util.typing import get_args
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_literal
 from ..util.typing import Protocol
 from ..util.typing import TypedDict
-from ..util.typing import typing_get_args
 
 if TYPE_CHECKING:
     from ._typing import _ClassDict
@@ -1319,7 +1319,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         extracted_mapped_annotation, mapped_container = extracted
 
         if attr_value is None and not is_literal(extracted_mapped_annotation):
-            for elem in typing_get_args(extracted_mapped_annotation):
+            for elem in get_args(extracted_mapped_annotation):
                 if isinstance(elem, str) or is_fwd_ref(
                     elem, check_generic=True
                 ):
index faf287cce6c2aeccc3f4fffb67e673ecf12d5758..4e07050a1d6c594e564f73b7bb08ae0237e7ed75 100644 (file)
@@ -53,9 +53,10 @@ from .. import util
 from ..sql import expression
 from ..sql import operators
 from ..sql.elements import BindParameter
+from ..util.typing import get_args
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_pep593
-from ..util.typing import typing_get_args
+
 
 if typing.TYPE_CHECKING:
     from ._typing import _InstanceDict
@@ -364,7 +365,7 @@ class CompositeProperty(
         argument = extracted_mapped_annotation
 
         if is_pep593(argument):
-            argument = typing_get_args(argument)[0]
+            argument = get_args(argument)[0]
 
         if argument and self.composite_class is None:
             if isinstance(argument, str) or is_fwd_ref(
index b6fb3d43e318b1716da38f3521af3cd613c02c21..96ae9d7f82ad1594e9201b5033c6c2e22e6cff73 100644 (file)
@@ -55,13 +55,13 @@ from ..sql.schema import Column
 from ..sql.schema import SchemaConst
 from ..sql.type_api import TypeEngine
 from ..util.typing import de_optionalize_union_types
+from ..util.typing import get_args
+from ..util.typing import includes_none
 from ..util.typing import is_fwd_ref
-from ..util.typing import is_optional_union
 from ..util.typing import is_pep593
 from ..util.typing import is_pep695
 from ..util.typing import is_union
 from ..util.typing import Self
-from ..util.typing import typing_get_args
 
 if TYPE_CHECKING:
     from ._typing import _IdentityKeyType
@@ -752,38 +752,36 @@ class MappedColumn(
                 cls, argument, originating_module
             )
 
-        nullable = is_optional_union(argument)
+        nullable = includes_none(argument)
 
         if not self._has_nullable:
             self.column.nullable = nullable
 
         our_type = de_optionalize_union_types(argument)
 
-        use_args_from = None
-
-        our_original_type = our_type
-
-        if is_pep695(our_type):
-            our_type = our_type.__value__
+        find_mapped_in: Tuple[Any, ...] = ()
+        our_type_is_pep593 = False
+        raw_pep_593_type = None
 
         if is_pep593(our_type):
             our_type_is_pep593 = True
 
-            pep_593_components = typing_get_args(our_type)
+            pep_593_components = get_args(our_type)
             raw_pep_593_type = pep_593_components[0]
-            if is_optional_union(raw_pep_593_type):
+            if nullable:
                 raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type)
-
-                nullable = True
-                if not self._has_nullable:
-                    self.column.nullable = nullable
-            for elem in pep_593_components[1:]:
-                if isinstance(elem, MappedColumn):
-                    use_args_from = elem
-                    break
+            find_mapped_in = pep_593_components[1:]
+        elif is_pep695(argument) and is_pep593(argument.__value__):
+            # nested annotations inside unions etc. are not supported
+            find_mapped_in = get_args(argument.__value__)[1:]
+
+        use_args_from: Optional[MappedColumn[Any]]
+        for elem in find_mapped_in:
+            if isinstance(elem, MappedColumn):
+                use_args_from = elem
+                break
         else:
-            our_type_is_pep593 = False
-            raw_pep_593_type = None
+            use_args_from = None
 
         if use_args_from is not None:
             if (
@@ -857,10 +855,11 @@ class MappedColumn(
         if sqltype._isnull and not self.column.foreign_keys:
             new_sqltype = None
 
+            checks: List[Any]
             if our_type_is_pep593:
-                checks = [our_original_type, raw_pep_593_type]
+                checks = [our_type, raw_pep_593_type]
             else:
-                checks = [our_original_type]
+                checks = [our_type]
 
             for check_type in checks:
                 new_sqltype = registry._resolve_type(check_type)
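
The ``is_pep695(argument) and is_pep593(argument.__value__)`` branch above
corresponds to the pattern exercised by ``test_extract_pep593_from_pep695``
further down: a PEP 695 alias whose value is an ``Annotated`` carrying a
``mapped_column()``. A condensed sketch of that usage (assumes Python 3.12):

    from typing import Annotated

    from sqlalchemy import VARCHAR
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type strtypalias_keyword = Annotated[
        str, mapped_column(info={"hi": "there"})
    ]


    class Base(DeclarativeBase):
        # the alias still needs a type map entry to resolve its SQL type
        type_annotation_map = {strtypalias_keyword: VARCHAR(33)}


    class MyClass(Base):
        __tablename__ = "my_table"

        id: Mapped[int] = mapped_column(primary_key=True)
        data_one: Mapped[strtypalias_keyword]


    # the mapped_column() inside the Annotated contributes its arguments
    assert MyClass.data_one.expression.info == {"hi": "there"}
    assert MyClass.data_one.type.length == 33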
index dbfa6d5f1b8e1414567f97f25638201a865ab3b1..11b6ac2c1ca71c84b98ae4317f58166977b4e158 100644 (file)
@@ -91,10 +91,10 @@ from ..util.typing import (
 )
 from ..util.typing import eval_name_only as _eval_name_only
 from ..util.typing import fixup_container_fwd_refs
+from ..util.typing import get_origin
 from ..util.typing import is_origin_of_cls
 from ..util.typing import Literal
 from ..util.typing import Protocol
-from ..util.typing import typing_get_origin
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
@@ -123,7 +123,7 @@ if typing.TYPE_CHECKING:
     from ..sql.selectable import Selectable
     from ..sql.visitors import anon_map
     from ..util.typing import _AnnotationScanType
-    from ..util.typing import ArgsTypeProcotol
+    from ..util.typing import ArgsTypeProtocol
 
 _T = TypeVar("_T", bound=Any)
 
@@ -177,7 +177,7 @@ class _DeStringifyUnionElements(Protocol):
     def __call__(
         self,
         cls: Type[Any],
-        annotation: ArgsTypeProcotol,
+        annotation: ArgsTypeProtocol,
         originating_module: str,
         *,
         str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
@@ -1543,7 +1543,7 @@ GenericAlias = type(List[Any])
 def _inspect_generic_alias(
     class_: Type[_O],
 ) -> Optional[Mapper[_O]]:
-    origin = cast("Type[_O]", typing_get_origin(class_))
+    origin = cast("Type[_O]", get_origin(class_))
     return _inspect_mc(origin)
 
 
index f16db640664e30148afcdd48faf27a1acadbaef8..a7d140ec6bdca88f6879a13a3cbe4cef57b99316 100644 (file)
@@ -59,9 +59,11 @@ from .. import util
 from ..engine import processors
 from ..util import langhelpers
 from ..util import OrderedDict
+from ..util import warn_deprecated
+from ..util.typing import get_args
 from ..util.typing import is_literal
+from ..util.typing import is_pep695
 from ..util.typing import Literal
-from ..util.typing import typing_get_args
 
 if TYPE_CHECKING:
     from ._typing import _ColumnExpressionArgument
@@ -1511,6 +1513,19 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
 
         native_enum = None
 
+        def process_literal(pt):
+            # for a literal, where we need to get its contents, parse it out.
+            enum_args = get_args(pt)
+            bad_args = [arg for arg in enum_args if not isinstance(arg, str)]
+            if bad_args:
+                raise exc.ArgumentError(
+                    f"Can't create string-based Enum datatype from non-string "
+                    f"values: {', '.join(repr(x) for x in bad_args)}.  Please "
+                    f"provide an explicit Enum datatype for this Python type"
+                )
+            native_enum = False
+            return enum_args, native_enum
+
         if not we_are_generic_form and python_type is matched_on:
             # if we have enumerated values, and the incoming python
             # type is exactly the one that matched in the type map,
@@ -1519,16 +1534,32 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             enum_args = self._enums_argument
 
         elif is_literal(python_type):
-            # for a literal, where we need to get its contents, parse it out.
-            enum_args = typing_get_args(python_type)
-            bad_args = [arg for arg in enum_args if not isinstance(arg, str)]
-            if bad_args:
+            enum_args, native_enum = process_literal(python_type)
+        elif is_pep695(python_type):
+            value = python_type.__value__
+            if is_pep695(value):
+                new_value = value
+                while is_pep695(new_value):
+                    new_value = new_value.__value__
+                if is_literal(new_value):
+                    value = new_value
+                    warn_deprecated(
+                        f"Mapping recursive TypeAliasType '{python_type}' "
+                        "that resolve to literal to generate an Enum is "
+                        "deprecated. SQLAlchemy 2.1 will not support this "
+                        "use case. Please avoid using recursing "
+                        "TypeAliasType.",
+                        "2.0",
+                    )
+            if not is_literal(value):
                 raise exc.ArgumentError(
-                    f"Can't create string-based Enum datatype from non-string "
-                    f"values: {', '.join(repr(x) for x in bad_args)}.  Please "
-                    f"provide an explicit Enum datatype for this Python type"
+                    f"Can't associate TypeAliasType '{python_type}' to an "
+                    "Enum since it's not a direct alias of a Literal. Only "
+                    "aliases in this form `type my_alias = Literal['a', "
+                    "'b']` are supported when generating Enums."
                 )
-            native_enum = False
+            enum_args, native_enum = process_literal(value)
+
         elif isinstance(python_type, type) and issubclass(
             python_type, enum.Enum
         ):
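
A sketch of the new error path above, mirroring the ``not_literal`` variant of
``test_pep695_literal_defaults_to_enum`` below (the ``NotALiteral`` name is
illustrative; assumes Python 3.12):

    import enum

    from sqlalchemy import Enum, exc
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type NotALiteral = str  # an alias of a plain type, not of a Literal


    class Base(DeclarativeBase):
        type_annotation_map = {NotALiteral: Enum(enum.Enum)}


    try:

        class Broken(Base):
            __tablename__ = "broken"

            id: Mapped[int] = mapped_column(primary_key=True)
            status: Mapped[NotALiteral]

    except exc.ArgumentError as err:
        # "Can't associate TypeAliasType ... to an Enum since it's not a
        # direct alias of a Literal ..."
        print(err)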
index bd1ebd4c01380e5b5351f0a1a093e3154ea06da4..645a41a24068614b65e5e54c5cf67f48772838ec 100644 (file)
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import builtins
+from collections import deque
 import collections.abc as collections_abc
 import re
 import sys
@@ -54,6 +55,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypeGuard as TypeGuard  # 3.10
     from typing_extensions import Self as Self  # 3.11
     from typing_extensions import TypeAliasType as TypeAliasType  # 3.12
+    from typing_extensions import Never as Never  # 3.11
 
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT")
@@ -65,9 +67,9 @@ _VT_co = TypeVar("_VT_co", covariant=True)
 if compat.py38:
     # typing_extensions.Literal is different from typing.Literal until
     # Python 3.10.1
-    _LITERAL_TYPES = frozenset([typing.Literal, Literal])
+    LITERAL_TYPES = frozenset([typing.Literal, Literal])
 else:
-    _LITERAL_TYPES = frozenset([Literal])
+    LITERAL_TYPES = frozenset([Literal])
 
 
 if compat.py310:
@@ -79,16 +81,13 @@ else:
 
 NoneFwd = ForwardRef("None")
 
-typing_get_args = get_args
-typing_get_origin = get_origin
-
 
 _AnnotationScanType = Union[
     Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
 ]
 
 
-class ArgsTypeProcotol(Protocol):
+class ArgsTypeProtocol(Protocol):
     """protocol for types that have ``__args__``
 
     there's no public interface for this AFAIK
@@ -209,7 +208,7 @@ def fixup_container_fwd_refs(
 
     if (
         is_generic(type_)
-        and typing_get_origin(type_)
+        and get_origin(type_)
         in (
             dict,
             set,
@@ -229,11 +228,11 @@ def fixup_container_fwd_refs(
         )
     ):
         # compat with py3.10 and earlier
-        return typing_get_origin(type_).__class_getitem__(  # type: ignore
+        return get_origin(type_).__class_getitem__(  # type: ignore
             tuple(
                 [
                     ForwardRef(elem) if isinstance(elem, str) else elem
-                    for elem in typing_get_args(type_)
+                    for elem in get_args(type_)
                 ]
             )
         )
@@ -332,7 +331,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
 
 def de_stringify_union_elements(
     cls: Type[Any],
-    annotation: ArgsTypeProcotol,
+    annotation: ArgsTypeProtocol,
     originating_module: str,
     locals_: Mapping[str, Any],
     *,
@@ -352,8 +351,8 @@ def de_stringify_union_elements(
     )
 
 
-def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
-    return type_ is not None and typing_get_origin(type_) is Annotated
+def is_pep593(type_: Optional[Any]) -> bool:
+    return type_ is not None and get_origin(type_) is Annotated
 
 
 def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
@@ -362,8 +361,8 @@ def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
     )
 
 
-def is_literal(type_: _AnnotationScanType) -> bool:
-    return get_origin(type_) in _LITERAL_TYPES
+def is_literal(type_: Any) -> bool:
+    return get_origin(type_) in LITERAL_TYPES
 
 
 def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
@@ -389,6 +388,43 @@ def flatten_newtype(type_: NewType) -> Type[Any]:
     return super_type  # type: ignore[return-value]
 
 
+def pep695_values(type_: _AnnotationScanType) -> Set[Any]:
+    """Extracts the value from a TypeAliasType, recursively exploring unions
+    and inner TypeAliasType to flatten them into a single set.
+
+    Forward references are not evaluated, so no recursive exploration happens
+    into them.
+    """
+    _seen = set()
+
+    def recursive_value(type_):
+        if type_ in _seen:
+            # recursive aliases are not supported (pyright flags them as
+            # an error); just avoid an infinite loop here
+            return type_
+        _seen.add(type_)
+        if not is_pep695(type_):
+            return type_
+        value = type_.__value__
+        if not is_union(value):
+            return value
+        return [recursive_value(t) for t in value.__args__]
+
+    res = recursive_value(type_)
+    if isinstance(res, list):
+        types = set()
+        stack = deque(res)
+        while stack:
+            t = stack.popleft()
+            if isinstance(t, list):
+                stack.extend(t)
+            else:
+                types.add(None if t in {NoneType, NoneFwd} else t)
+        return types
+    else:
+        return {res}
+
+
 def is_fwd_ref(
     type_: _AnnotationScanType, check_generic: bool = False
 ) -> TypeGuard[ForwardRef]:
@@ -422,13 +458,10 @@ def de_optionalize_union_types(
 
     """
 
-    while is_pep695(type_):
-        type_ = type_.__value__
-
     if is_fwd_ref(type_):
-        return de_optionalize_fwd_ref_union_types(type_)
+        return _de_optionalize_fwd_ref_union_types(type_, False)
 
-    elif is_optional(type_):
+    elif is_union(type_) and includes_none(type_):
         typ = set(type_.__args__)
 
         typ.discard(NoneType)
@@ -440,9 +473,21 @@ def de_optionalize_union_types(
         return type_
 
 
-def de_optionalize_fwd_ref_union_types(
-    type_: ForwardRef,
-) -> _AnnotationScanType:
+@overload
+def _de_optionalize_fwd_ref_union_types(
+    type_: ForwardRef, return_has_none: Literal[True]
+) -> bool: ...
+
+
+@overload
+def _de_optionalize_fwd_ref_union_types(
+    type_: ForwardRef, return_has_none: Literal[False]
+) -> _AnnotationScanType: ...
+
+
+def _de_optionalize_fwd_ref_union_types(
+    type_: ForwardRef, return_has_none: bool
+) -> Union[_AnnotationScanType, bool]:
     """return the non-optional type for Optional[], Union[None, ...], x|None,
     etc. without de-stringifying forward refs.
 
@@ -454,47 +499,77 @@ def de_optionalize_fwd_ref_union_types(
 
     mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
     if mm:
-        if mm.group(1) == "Optional":
-            return ForwardRef(mm.group(2))
-        elif mm.group(1) == "Union":
-            elements = re.split(r",\s*", mm.group(2))
-            return make_union_type(
-                *[ForwardRef(elem) for elem in elements if elem != "None"]
-            )
+        g1 = mm.group(1).split(".")[-1]
+        if g1 == "Optional":
+            return True if return_has_none else ForwardRef(mm.group(2))
+        elif g1 == "Union":
+            if "[" in mm.group(2):
+                # cases like "Union[Dict[str, int], int, None]"
+                elements: list[str] = []
+                current: list[str] = []
+                ignore_comma = 0
+                for char in mm.group(2):
+                    if char == "[":
+                        ignore_comma += 1
+                    elif char == "]":
+                        ignore_comma -= 1
+                    elif ignore_comma == 0 and char == ",":
+                        elements.append("".join(current).strip())
+                        current.clear()
+                        continue
+                    current.append(char)
+            else:
+                elements = re.split(r",\s*", mm.group(2))
+            parts = [ForwardRef(elem) for elem in elements if elem != "None"]
+            if return_has_none:
+                return len(elements) != len(parts)
+            else:
+                return make_union_type(*parts) if parts else Never  # type: ignore[return-value] # noqa: E501
         else:
-            return type_
+            return False if return_has_none else type_
 
     pipe_tokens = re.split(r"\s*\|\s*", annotation)
-    if "None" in pipe_tokens:
-        return ForwardRef("|".join(p for p in pipe_tokens if p != "None"))
+    has_none = "None" in pipe_tokens
+    if return_has_none:
+        return has_none
+    if has_none:
+        anno_str = "|".join(p for p in pipe_tokens if p != "None")
+        return ForwardRef(anno_str) if anno_str else Never  # type: ignore[return-value] # noqa: E501
 
     return type_
 
 
 def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
-    """Make a Union type.
+    """Make a Union type."""
+    return Union.__getitem__(types)  # type: ignore
 
-    This is needed by :func:`.de_optionalize_union_types` which removes
-    ``NoneType`` from a ``Union``.
 
-    """
-    return cast(Any, Union).__getitem__(types)  # type: ignore
-
-
-def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
-    return is_origin_of(
-        type_,
-        "Optional",
-        "Union",
-        "UnionType",
-    )
+def includes_none(type_: Any) -> bool:
+    """Returns if the type annotation ``type_`` allows ``None``.
 
-
-def is_optional_union(type_: Any) -> bool:
-    return is_optional(type_) and NoneType in typing_get_args(type_)
-
-
-def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
+    This function supports:
+    * forward refs
+    * unions
+    * pep593 - Annotated
+    * pep695 - TypeAliasType (does not support looking into forward
+    references of other pep695 aliases)
+    * NewType
+    * plain types like ``int``, ``None``, etc
+    """
+    if is_fwd_ref(type_):
+        return _de_optionalize_fwd_ref_union_types(type_, True)
+    if is_union(type_):
+        return any(includes_none(t) for t in get_args(type_))
+    if is_pep593(type_):
+        return includes_none(get_args(type_)[0])
+    if is_pep695(type_):
+        return any(includes_none(t) for t in pep695_values(type_))
+    if is_newtype(type_):
+        return includes_none(type_.__supertype__)
+    return type_ in (NoneFwd, NoneType, None)
+
+
+def is_union(type_: Any) -> TypeGuard[ArgsTypeProtocol]:
     return is_origin_of(type_, "Union", "UnionType")
 
 
@@ -504,7 +579,7 @@ def is_origin_of_cls(
     """return True if the given type has an __origin__ that shares a base
     with the given class"""
 
-    origin = typing_get_origin(type_)
+    origin = get_origin(type_)
     if origin is None:
         return False
 
@@ -517,7 +592,7 @@ def is_origin_of(
     """return True if the given type has an __origin__ with the given name
     and optional module."""
 
-    origin = typing_get_origin(type_)
+    origin = get_origin(type_)
     if origin is None:
         return False
 
@@ -607,6 +682,3 @@ class CallableReference(Generic[_FN]):
         def __set__(self, instance: Any, value: _FN) -> None: ...
 
         def __delete__(self, instance: Any) -> None: ...
-
-
-# $def ro_descriptor_reference(fn: Callable[])
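
Before the new test module below, a quick sketch of what the new helpers
return (mirroring assertions in ``test_typing_utils.py``;
``typing_extensions.TypeAliasType`` is used so this runs before Python 3.12):

    import typing

    import typing_extensions

    from sqlalchemy.util import typing as sa_typing

    TA_null_union = typing_extensions.TypeAliasType(
        "TA_null_union", typing.Union[int, str, None]
    )

    # the alias is flattened to the set of types its value unions over
    assert sa_typing.pep695_values(TA_null_union) == {int, str, None}

    # includes_none() drives column nullability for Mapped[] annotations
    assert sa_typing.includes_none(TA_null_union) is True
    assert sa_typing.includes_none(typing.Union[int, str]) is False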
diff --git a/test/base/test_typing_utils.py b/test/base/test_typing_utils.py
new file mode 100644 (file)
index 0000000..67e7bf4
--- /dev/null
@@ -0,0 +1,409 @@
+# NOTE: the typing implementation is full of heuristics, so unit test it to
+# avoid unexpected breakages.
+
+import typing
+
+import typing_extensions
+
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import requires
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import is_
+from sqlalchemy.util import py310
+from sqlalchemy.util import py311
+from sqlalchemy.util import py312
+from sqlalchemy.util import py38
+from sqlalchemy.util import typing as sa_typing
+
+TV = typing.TypeVar("TV")
+
+
+def union_types():
+    res = [typing.Union[int, str]]
+    if py310:
+        res.append(int | str)
+    return res
+
+
+def null_union_types():
+    res = [
+        typing.Optional[typing.Union[int, str]],
+        typing.Union[int, str, None],
+        typing.Union[int, str, "None"],
+    ]
+    if py310:
+        res.append(int | str | None)
+        res.append(typing.Optional[int | str])
+        res.append(typing.Union[int, str] | None)
+        res.append(typing.Optional[int] | str)
+    return res
+
+
+def make_fw_ref(anno: str) -> typing.ForwardRef:
+    return typing.Union[anno]
+
+
+TA_int = typing_extensions.TypeAliasType("TA_int", int)
+TA_union = typing_extensions.TypeAliasType("TA_union", typing.Union[int, str])
+TA_null_union = typing_extensions.TypeAliasType(
+    "TA_null_union", typing.Union[int, str, None]
+)
+TA_null_union2 = typing_extensions.TypeAliasType(
+    "TA_null_union2", typing.Union[int, str, "None"]
+)
+TA_null_union3 = typing_extensions.TypeAliasType(
+    "TA_null_union3", typing.Union[int, "typing.Union[None, bool]"]
+)
+TA_null_union4 = typing_extensions.TypeAliasType(
+    "TA_null_union4", typing.Union[int, "TA_null_union2"]
+)
+TA_union_ta = typing_extensions.TypeAliasType(
+    "TA_union_ta", typing.Union[TA_int, str]
+)
+TA_null_union_ta = typing_extensions.TypeAliasType(
+    "TA_null_union_ta", typing.Union[TA_null_union, float]
+)
+TA_list = typing_extensions.TypeAliasType(
+    "TA_list", typing.Union[int, str, typing.List["TA_list"]]
+)
+# the aliases below are not valid. Verify that they do not raise in any case
+TA_recursive = typing_extensions.TypeAliasType(
+    "TA_recursive", typing.Union["TA_recursive", str]
+)
+TA_null_recursive = typing_extensions.TypeAliasType(
+    "TA_null_recursive", typing.Union[TA_recursive, None]
+)
+TA_recursive_a = typing_extensions.TypeAliasType(
+    "TA_recursive_a", typing.Union["TA_recursive_b", int]
+)
+TA_recursive_b = typing_extensions.TypeAliasType(
+    "TA_recursive_b", typing.Union["TA_recursive_a", str]
+)
+
+
+def type_aliases():
+    return [
+        TA_int,
+        TA_union,
+        TA_null_union,
+        TA_null_union2,
+        TA_null_union3,
+        TA_null_union4,
+        TA_union_ta,
+        TA_null_union_ta,
+        TA_list,
+        TA_recursive,
+        TA_null_recursive,
+        TA_recursive_a,
+        TA_recursive_b,
+    ]
+
+
+NT_str = typing.NewType("NT_str", str)
+NT_null = typing.NewType("NT_null", None)
+# the NewType below is not valid. Verify that it does not raise in any case
+NT_union = typing.NewType("NT_union", typing.Union[str, int])
+
+
+def new_types():
+    return [NT_str, NT_null, NT_union]
+
+
+A_str = typing_extensions.Annotated[str, "meta"]
+A_null_str = typing_extensions.Annotated[
+    typing.Union[str, None], "other_meta", "null"
+]
+A_union = typing_extensions.Annotated[typing.Union[str, int], "other_meta"]
+A_null_union = typing_extensions.Annotated[
+    typing.Union[str, int, None], "other_meta", "null"
+]
+
+
+def annotated_l():
+    return [A_str, A_null_str, A_union, A_null_union]
+
+
+def all_types():
+    return (
+        union_types()
+        + null_union_types()
+        + type_aliases()
+        + new_types()
+        + annotated_l()
+    )
+
+
+def exec_code(code: str, *vars: str) -> typing.Any:
+    assert vars
+    scope = {}
+    exec(code, None, scope)
+    if len(vars) == 1:
+        return scope[vars[0]]
+    return [scope[name] for name in vars]
+
+
+class TestTestingThings(fixtures.TestBase):
+    def test_unions_are_the_same(self):
+        # no need to test typing_extensions.Union, typing_extensions.Optional
+        is_(typing.Union, typing_extensions.Union)
+        is_(typing.Optional, typing_extensions.Optional)
+        if py312:
+            is_(typing.TypeAliasType, typing_extensions.TypeAliasType)
+
+    def test_make_union(self):
+        v = int, str
+        eq_(typing.Union[int, str], typing.Union.__getitem__(v))
+        if py311:
+            # need eval since it's a syntax error in python < 3.11
+            eq_(typing.Union[int, str], eval("typing.Union[*(int, str)]"))
+            eq_(typing.Union[int, str], eval("typing.Union[*v]"))
+
+    @requires.python312
+    def test_make_type_alias_type(self):
+        # verify that TypeAliasType('foo', int) is the same as 'type foo = int'
+        x_type = exec_code("type x = int", "x")
+        x = typing.TypeAliasType("x", int)
+
+        eq_(type(x_type), type(x))
+        eq_(x_type.__name__, x.__name__)
+        eq_(x_type.__value__, x.__value__)
+
+    def test_make_fw_ref(self):
+        eq_(make_fw_ref("str"), typing.ForwardRef("str"))
+        eq_(make_fw_ref("str|int"), typing.ForwardRef("str|int"))
+        eq_(
+            make_fw_ref("Optional[Union[str, int]]"),
+            typing.ForwardRef("Optional[Union[str, int]]"),
+        )
+
+
+class TestTyping(fixtures.TestBase):
+    def test_is_pep593(self):
+        eq_(sa_typing.is_pep593(str), False)
+        eq_(sa_typing.is_pep593(None), False)
+        eq_(sa_typing.is_pep593(typing_extensions.Annotated[int, "a"]), True)
+        if py310:
+            eq_(sa_typing.is_pep593(typing.Annotated[int, "a"]), True)
+
+        for t in annotated_l():
+            eq_(sa_typing.is_pep593(t), True)
+        for t in (
+            union_types() + null_union_types() + type_aliases() + new_types()
+        ):
+            eq_(sa_typing.is_pep593(t), False)
+
+    def test_is_literal(self):
+        if py38:
+            eq_(sa_typing.is_literal(typing.Literal["a"]), True)
+        eq_(sa_typing.is_literal(typing_extensions.Literal["a"]), True)
+        eq_(sa_typing.is_literal(None), False)
+        for t in all_types():
+            eq_(sa_typing.is_literal(t), False)
+
+    def test_is_newtype(self):
+        eq_(sa_typing.is_newtype(str), False)
+
+        for t in new_types():
+            eq_(sa_typing.is_newtype(t), True)
+        for t in (
+            union_types() + null_union_types() + type_aliases() + annotated_l()
+        ):
+            eq_(sa_typing.is_newtype(t), False)
+
+    def test_is_generic(self):
+        class W(typing.Generic[TV]):
+            pass
+
+        eq_(sa_typing.is_generic(typing.List[int]), True)
+        eq_(sa_typing.is_generic(W), False)
+        eq_(sa_typing.is_generic(W[str]), True)
+
+        if py312:
+            t = exec_code("class W[T]: pass", "W")
+            eq_(sa_typing.is_generic(t), False)
+            eq_(sa_typing.is_generic(t[int]), True)
+
+        for t in all_types():
+            eq_(sa_typing.is_literal(t), False)
+
+    def test_is_pep695(self):
+        eq_(sa_typing.is_pep695(str), False)
+        for t in (
+            union_types() + null_union_types() + new_types() + annotated_l()
+        ):
+            eq_(sa_typing.is_pep695(t), False)
+        for t in type_aliases():
+            eq_(sa_typing.is_pep695(t), True)
+
+    def test_pep695_value(self):
+        eq_(sa_typing.pep695_values(int), {int})
+        eq_(
+            sa_typing.pep695_values(typing.Union[int, str]),
+            {typing.Union[int, str]},
+        )
+
+        for t in (
+            union_types() + null_union_types() + new_types() + annotated_l()
+        ):
+            eq_(sa_typing.pep695_values(t), {t})
+
+        eq_(
+            sa_typing.pep695_values(typing.Union[int, TA_int]),
+            {typing.Union[int, TA_int]},
+        )
+
+        eq_(sa_typing.pep695_values(TA_int), {int})
+        eq_(sa_typing.pep695_values(TA_union), {int, str})
+        eq_(sa_typing.pep695_values(TA_null_union), {int, str, None})
+        eq_(sa_typing.pep695_values(TA_null_union2), {int, str, None})
+        eq_(
+            sa_typing.pep695_values(TA_null_union3),
+            {int, typing.ForwardRef("typing.Union[None, bool]")},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_null_union4),
+            {int, typing.ForwardRef("TA_null_union2")},
+        )
+        eq_(sa_typing.pep695_values(TA_union_ta), {int, str})
+        eq_(sa_typing.pep695_values(TA_null_union_ta), {int, str, None, float})
+        eq_(
+            sa_typing.pep695_values(TA_list),
+            {int, str, typing.List[typing.ForwardRef("TA_list")]},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_recursive),
+            {typing.ForwardRef("TA_recursive"), str},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_null_recursive),
+            {typing.ForwardRef("TA_recursive"), str, None},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_recursive_a),
+            {typing.ForwardRef("TA_recursive_b"), int},
+        )
+        eq_(
+            sa_typing.pep695_values(TA_recursive_b),
+            {typing.ForwardRef("TA_recursive_a"), str},
+        )
+
+    def test_is_fwd_ref(self):
+        eq_(sa_typing.is_fwd_ref(int), False)
+        eq_(sa_typing.is_fwd_ref(make_fw_ref("str")), True)
+        eq_(sa_typing.is_fwd_ref(typing.Union[str, int]), False)
+        eq_(sa_typing.is_fwd_ref(typing.Union["str", int]), False)
+        eq_(sa_typing.is_fwd_ref(typing.Union["str", int], True), True)
+
+        for t in all_types():
+            eq_(sa_typing.is_fwd_ref(t), False)
+
+    def test_de_optionalize_union_types(self):
+        fn = sa_typing.de_optionalize_union_types
+
+        eq_(
+            fn(typing.Optional[typing.Union[int, str]]), typing.Union[int, str]
+        )
+        eq_(fn(typing.Union[int, str, None]), typing.Union[int, str])
+        eq_(fn(typing.Union[int, str, "None"]), typing.Union[int, str])
+
+        eq_(fn(make_fw_ref("None")), typing_extensions.Never)
+        eq_(fn(make_fw_ref("typing.Union[None]")), typing_extensions.Never)
+        eq_(fn(make_fw_ref("Union[None, str]")), typing.ForwardRef("str"))
+        eq_(
+            fn(make_fw_ref("Union[None, str, int]")),
+            typing.Union["str", "int"],
+        )
+        eq_(fn(make_fw_ref("Optional[int]")), typing.ForwardRef("int"))
+        eq_(
+            fn(make_fw_ref("typing.Optional[Union[int | str]]")),
+            typing.ForwardRef("Union[int | str]"),
+        )
+
+        for t in null_union_types():
+            res = fn(t)
+            eq_(sa_typing.is_union(res), True)
+            eq_(type(None) not in res.__args__, True)
+
+        for t in union_types() + type_aliases() + new_types() + annotated_l():
+            eq_(fn(t), t)
+
+        eq_(
+            fn(make_fw_ref("Union[typing.Dict[str, int], int, None]")),
+            typing.Union["typing.Dict[str, int]", "int"],
+        )
+
+    def test_make_union_type(self):
+        eq_(sa_typing.make_union_type(int), int)
+        eq_(sa_typing.make_union_type(None), type(None))
+        eq_(sa_typing.make_union_type(int, str), typing.Union[int, str])
+        eq_(
+            sa_typing.make_union_type(int, typing.Optional[str]),
+            typing.Union[int, str, None],
+        )
+        eq_(
+            sa_typing.make_union_type(int, typing.Union[str, bool]),
+            typing.Union[int, str, bool],
+        )
+        eq_(
+            sa_typing.make_union_type(bool, TA_int, NT_str),
+            typing.Union[bool, TA_int, NT_str],
+        )
+
+    def test_includes_none(self):
+        eq_(sa_typing.includes_none(None), True)
+        eq_(sa_typing.includes_none(type(None)), True)
+        eq_(sa_typing.includes_none(typing.ForwardRef("None")), True)
+        eq_(sa_typing.includes_none(int), False)
+        for t in union_types():
+            eq_(sa_typing.includes_none(t), False)
+
+        for t in null_union_types():
+            eq_(sa_typing.includes_none(t), True, str(t))
+
+        # TODO: these are false negatives
+        false_negative = {
+            TA_null_union4,  # does not evaluate FW ref
+        }
+        for t in type_aliases() + new_types():
+            if t in false_negative:
+                exp = False
+            else:
+                exp = "null" in t.__name__
+            eq_(sa_typing.includes_none(t), exp, str(t))
+
+        for t in annotated_l():
+            eq_(
+                sa_typing.includes_none(t),
+                "null" in sa_typing.get_args(t),
+                str(t),
+            )
+        # nested things
+        eq_(sa_typing.includes_none(typing.Union[int, "None"]), True)
+        eq_(sa_typing.includes_none(typing.Union[bool, TA_null_union]), True)
+        eq_(sa_typing.includes_none(typing.Union[bool, NT_null]), True)
+        # nested fw
+        eq_(
+            sa_typing.includes_none(
+                typing.Union[int, "typing.Union[str, None]"]
+            ),
+            True,
+        )
+        eq_(
+            sa_typing.includes_none(
+                typing.Union[int, "typing.Union[int, str]"]
+            ),
+            False,
+        )
+
+        # these are not supported; ideally they should return True
+        eq_(
+            sa_typing.includes_none(typing.Union[bool, "TA_null_union"]), False
+        )
+        eq_(sa_typing.includes_none(typing.Union[bool, "NT_null"]), False)
+
+    def test_is_union(self):
+        eq_(sa_typing.is_union(str), False)
+        for t in union_types() + null_union_types():
+            eq_(sa_typing.is_union(t), True)
+        for t in type_aliases() + new_types() + annotated_l():
+            eq_(sa_typing.is_union(t), False)
index 85c419e94e81f18eac042a7219ba901ca36e2c1c..de8712c852343b54c356695ec7f0413f4ad7bcd3 100644 (file)
@@ -4,9 +4,6 @@ import inspect
 from pathlib import Path
 import pickle
 import sys
-import typing
-
-import typing_extensions
 
 from sqlalchemy import exc
 from sqlalchemy import sql
@@ -42,7 +39,6 @@ from sqlalchemy.util import WeakSequence
 from sqlalchemy.util._collections import merge_lists_w_ordering
 from sqlalchemy.util._has_cy import _import_cy_extensions
 from sqlalchemy.util._has_cy import HAS_CYEXTENSION
-from sqlalchemy.util.typing import is_union
 
 
 class WeakSequenceTest(fixtures.TestBase):
@@ -3634,11 +3630,3 @@ class CyExtensionTest(fixtures.TestBase):
             for f in cython_files
         }
         eq_({m.__name__ for m in ext}, set(names))
-
-
-class TypingTest(fixtures.TestBase):
-    def test_is_union(self):
-        assert is_union(typing.Union[str, int])
-        assert is_union(typing_extensions.Union[str, int])
-        if compat.py310:
-            assert is_union(str | int)
index c34d54169e8fcb1bab4f388eb99fe572c1564306..165f43b42d3897c750fb7a48bbc20ccdadbdbd65 100644 (file)
@@ -1,8 +1,8 @@
 """This file includes annotation-sensitive tests while having
 ``from __future__ import annotations`` in effect.
 
-Only tests that don't have an equivalent in ``test_typed_mappings`` are
-specified here. All test from ``test_typed_mappings`` are copied over to
+Only tests that don't have an equivalent in ``test_typed_mapping`` are
+specified here. All tests from ``test_typed_mapping`` are copied over to
 the ``test_tm_future_annotations_sync`` by the ``sync_test_file`` script.
 """
 
index a2eac4d7f4f989237be6303acea9d3bc42210856..4b3792663885107ed68e82cd43dc8981ff91d134 100644 (file)
@@ -96,8 +96,9 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
-from sqlalchemy.testing import skip_test
+from sqlalchemy.testing import requires
 from sqlalchemy.testing import Variation
+from sqlalchemy.testing.assertions import ne_
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
@@ -118,11 +119,6 @@ _StrTypeAlias: TypeAlias = str
 _StrPep695: TypeAlias = str
 _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
 
-_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"]
-_Recursive695_0: TypeAlias = _Literal695
-_Recursive695_1: TypeAlias = _Recursive695_0
-_Recursive695_2: TypeAlias = _Recursive695_1
-
 if compat.py38:
     _TypingLiteral = typing.Literal["a", "b"]
 _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
@@ -147,16 +143,16 @@ type _UnionPep695 = _SomeDict1 | _SomeDict2
 type _StrPep695 = str
 
 type strtypalias_keyword = Annotated[str, mapped_column(info={"hi": "there"})]
-
-strtypalias_tat: typing.TypeAliasType = Annotated[
+type strtypalias_keyword_nested = int | Annotated[
+    str, mapped_column(info={"hi": "there"})]
+strtypalias_ta: typing.TypeAlias = Annotated[
     str, mapped_column(info={"hi": "there"})]
-
 strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
 
 type _Literal695 = Literal["to-do", "in-progress", "done"]
-type _Recursive695_0 = _Literal695
-type _Recursive695_1 = _Recursive695_0
-type _Recursive695_2 = _Recursive695_1
+type _RecursiveLiteral695 = _Literal695
+
+type _JsonPep695 = _JsonPep604
 """,
         globals(),
     )
@@ -856,6 +852,84 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
+    @testing.variation(
+        "option",
+        [
+            "plain",
+            "union",
+            "union_604",
+            "union_null",
+            "union_null_604",
+            "optional",
+            "optional_union",
+            "optional_union_604",
+        ],
+    )
+    @testing.variation("in_map", ["yes", "no", "value"])
+    @testing.requires.python312
+    def test_pep695_behavior(self, decl_base, in_map, option):
+        """Issue #11955"""
+        global tat
+
+        if option.plain:
+            tat = TypeAliasType("tat", str)
+        elif option.union:
+            tat = TypeAliasType("tat", Union[str, int])
+        elif option.union_604:
+            tat = TypeAliasType("tat", str | int)
+        elif option.union_null:
+            tat = TypeAliasType("tat", Union[str, int, None])
+        elif option.union_null_604:
+            tat = TypeAliasType("tat", str | int | None)
+        elif option.optional:
+            tat = TypeAliasType("tat", Optional[str])
+        elif option.optional_union:
+            tat = TypeAliasType("tat", Optional[Union[str, int]])
+        elif option.optional_union_604:
+            tat = TypeAliasType("tat", Optional[str | int])
+        else:
+            option.fail()
+
+        if in_map.yes:
+            decl_base.registry.update_type_annotation_map({tat: String(99)})
+        elif in_map.value:
+            decl_base.registry.update_type_annotation_map(
+                {tat.__value__: String(99)}
+            )
+
+        def declare():
+            class Test(decl_base):
+                __tablename__ = "test"
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[tat]
+
+            return Test.__table__.c.data
+
+        if in_map.yes:
+            col = declare()
+            length = 99
+        elif in_map.value or option.optional or option.plain:
+            with expect_deprecated(
+                "Matching the provided TypeAliasType 'tat' on its "
+                "resolved value without matching it in the "
+                "type_annotation_map is deprecated; add this type to the "
+                "type_annotation_map to allow it to match explicitly.",
+            ):
+                col = declare()
+            length = 99 if in_map.value else None
+        else:
+            with expect_raises_message(
+                exc.ArgumentError,
+                "Could not locate SQLAlchemy Core type for Python type",
+            ):
+                declare()
+            return
+
+        is_true(isinstance(col.type, String))
+        eq_(col.type.length, length)
+        nullable = "null" in option.name or "optional" in option.name
+        eq_(col.nullable, nullable)
+
     @testing.requires.python312
     def test_pep695_typealias_as_typemap_keys(
         self, decl_base: Type[DeclarativeBase]
@@ -876,12 +950,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
-    @testing.variation("alias_type", ["none", "typekeyword", "typealiastype"])
+    @testing.variation(
+        "alias_type",
+        ["none", "typekeyword", "typealias", "typekeyword_nested"],
+    )
     @testing.requires.python312
     def test_extract_pep593_from_pep695(
         self, decl_base: Type[DeclarativeBase], alias_type
     ):
         """test #11130"""
+        if alias_type.typekeyword:
+            decl_base.registry.update_type_annotation_map(
+                {strtypalias_keyword: VARCHAR(33)}  # noqa: F821
+            )
+        if alias_type.typekeyword_nested:
+            decl_base.registry.update_type_annotation_map(
+                {strtypalias_keyword_nested: VARCHAR(42)}  # noqa: F821
+            )
 
         class MyClass(decl_base):
             __tablename__ = "my_table"
@@ -890,33 +975,96 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
             if alias_type.typekeyword:
                 data_one: Mapped[strtypalias_keyword]  # noqa: F821
-            elif alias_type.typealiastype:
-                data_one: Mapped[strtypalias_tat]  # noqa: F821
+            elif alias_type.typealias:
+                data_one: Mapped[strtypalias_ta]  # noqa: F821
             elif alias_type.none:
                 data_one: Mapped[strtypalias_plain]  # noqa: F821
+            elif alias_type.typekeyword_nested:
+                data_one: Mapped[strtypalias_keyword_nested]  # noqa: F821
             else:
                 alias_type.fail()
 
         table = MyClass.__table__
         assert table is not None
 
-        eq_(MyClass.data_one.expression.info, {"hi": "there"})
+        if alias_type.typekeyword_nested:
+            # Annotated nested inside a union alias is not extracted
+            eq_(MyClass.data_one.expression.info, {})
+        else:
+            eq_(MyClass.data_one.expression.info, {"hi": "there"})
 
+        if alias_type.typekeyword:
+            eq_(MyClass.data_one.type.length, 33)
+        elif alias_type.typekeyword_nested:
+            eq_(MyClass.data_one.type.length, 42)
+        else:
+            eq_(MyClass.data_one.type.length, None)
+
+    @testing.variation("type_", ["literal", "recursive", "not_literal"])
+    @testing.combinations(True, False, argnames="in_map")
     @testing.requires.python312
-    def test_pep695_literal_defaults_to_enum(self, decl_base):
+    def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map):
         """test #11305."""
 
-        class Foo(decl_base):
-            __tablename__ = "footable"
+        def declare():
+            class Foo(decl_base):
+                __tablename__ = "footable"
 
-            id: Mapped[int] = mapped_column(primary_key=True)
-            status: Mapped[_Literal695]
-            r2: Mapped[_Recursive695_2]
+                id: Mapped[int] = mapped_column(primary_key=True)
+                if type_.recursive:
+                    status: Mapped[_RecursiveLiteral695]  # noqa: F821
+                elif type_.literal:
+                    status: Mapped[_Literal695]  # noqa: F821
+                elif type_.not_literal:
+                    status: Mapped[_StrPep695]  # noqa: F821
+                else:
+                    type_.fail()
+
+            return Foo
 
-        for col in (Foo.__table__.c.status, Foo.__table__.c.r2):
+        if in_map:
+            decl_base.registry.update_type_annotation_map(
+                {
+                    _Literal695: Enum(enum.Enum),  # noqa: F821
+                    _RecursiveLiteral695: Enum(enum.Enum),  # noqa: F821
+                    _StrPep695: Enum(enum.Enum),  # noqa: F821
+                }
+            )
+            if type_.recursive:
+                with expect_deprecated(
+                    "Mapping recursive TypeAliasType '.+' that resolve to "
+                    "literal to generate an Enum is deprecated. SQLAlchemy "
+                    "2.1 will not support this use case. Please avoid using "
+                    "recursing TypeAliasType",
+                ):
+                    Foo = declare()
+            elif type_.literal:
+                Foo = declare()
+            else:
+                with expect_raises_message(
+                    exc.ArgumentError,
+                    "Can't associate TypeAliasType '.+' to an Enum "
+                    "since it's not a direct alias of a Literal. Only "
+                    "aliases in this form `type my_alias = Literal.'a', "
+                    "'b'.` are supported when generating Enums.",
+                ):
+                    declare()
+                return
+        else:
+            with expect_deprecated(
+                "Matching the provided TypeAliasType '.*' on its "
+                "resolved value without matching it in the "
+                "type_annotation_map is deprecated; add this type to the "
+                "type_annotation_map to allow it to match explicitly.",
+            ):
+                Foo = declare()
+        col = Foo.__table__.c.status
+        if in_map and not type_.not_literal:
             is_true(isinstance(col.type, Enum))
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
+        else:
+            is_true(isinstance(col.type, String))
 
     @testing.requires.python38
     def test_typing_literal_identity(self, decl_base):
@@ -1233,6 +1381,33 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(MyClass.__table__.c.data_four.type.length, 150)
         is_false(MyClass.__table__.c.data_four.nullable)
 
+    def test_newtype_missing_from_map(self, decl_base):
+        global str50
+
+        str50 = NewType("str50", str)
+
+        if compat.py310:
+            text = ".*str50"
+        else:
+            # NewTypes before 3.10 had a very bad repr
+            # <function NewType.<locals>.new_type at 0x...>
+            text = ".*NewType.*"
+
+        with expect_deprecated(
+            f"Matching the provided NewType '{text}' on its "
+            "resolved value without matching it in the "
+            "type_annotation_map is deprecated; add this type to the "
+            "type_annotation_map to allow it to match explicitly.",
+        ):
+
+            class MyClass(decl_base):
+                __tablename__ = "my_table"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data_one: Mapped[str50]
+
+        is_true(isinstance(MyClass.data_one.type, String))
+
     def test_extract_base_type_from_pep593(
         self, decl_base: Type[DeclarativeBase]
     ):
@@ -1724,39 +1899,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
-    @testing.variation("union", ["union", "pep604"])
-    @testing.variation("typealias", ["legacy", "pep695"])
-    def test_unions(self, union, typealias):
+    @testing.variation(
+        "union",
+        [
+            "union",
+            ("pep604", requires.python310),
+            "union_null",
+            ("pep604_null", requires.python310),
+        ],
+    )
+    def test_unions(self, union):
+        global UnionType
         our_type = Numeric(10, 2)
 
         if union.union:
             UnionType = Union[float, Decimal]
+        elif union.union_null:
+            UnionType = Union[float, Decimal, None]
         elif union.pep604:
-            if not compat.py310:
-                skip_test("Required Python 3.10")
             UnionType = float | Decimal
+        elif union.pep604_null:
+            UnionType = float | Decimal | None
         else:
             union.fail()
 
-        if typealias.legacy:
-            UnionTypeAlias = UnionType
-        elif typealias.pep695:
-            # same as type UnionTypeAlias = UnionType
-            UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
-        else:
-            typealias.fail()
-
         class Base(DeclarativeBase):
-            type_annotation_map = {UnionTypeAlias: our_type}
+            type_annotation_map = {UnionType: our_type}
 
         class User(Base):
             __tablename__ = "users"
-            __table__: Table
 
             id: Mapped[int] = mapped_column(primary_key=True)
 
-            data: Mapped[Union[float, Decimal]] = mapped_column()
-            reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+            data: Mapped[Union[float, Decimal]]
+            reverse_data: Mapped[Union[Decimal, float]]
 
             optional_data: Mapped[Optional[Union[float, Decimal]]] = (
                 mapped_column()
@@ -1773,6 +1949,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 mapped_column()
             )
 
+            refer_union: Mapped[UnionType]
+            refer_union_optional: Mapped[Optional[UnionType]]
+
             float_data: Mapped[float] = mapped_column()
             decimal_data: Mapped[Decimal] = mapped_column()
 
@@ -1788,65 +1967,54 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column()
                 )
 
-            if compat.py312:
-                MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
-                pep695_data: Mapped[MyTypeAlias] = mapped_column()
-
-        is_(User.__table__.c.data.type, our_type)
-        is_false(User.__table__.c.data.nullable)
-        is_(User.__table__.c.reverse_data.type, our_type)
-        is_(User.__table__.c.optional_data.type, our_type)
-        is_true(User.__table__.c.optional_data.nullable)
+        info = [
+            ("data", False),
+            ("reverse_data", False),
+            ("optional_data", True),
+            ("reverse_optional_data", True),
+            ("reverse_u_optional_data", True),
+            ("refer_union", "null" in union.name),
+            ("refer_union_optional", True),
+        ]
+        if compat.py310:
+            info += [
+                ("pep604_data", False),
+                ("pep604_reverse", False),
+                ("pep604_optional", True),
+                ("pep604_data_fwd", False),
+                ("pep604_reverse_fwd", False),
+                ("pep604_optional_fwd", True),
+            ]
 
-        is_(User.__table__.c.reverse_optional_data.type, our_type)
-        is_(User.__table__.c.reverse_u_optional_data.type, our_type)
-        is_true(User.__table__.c.reverse_optional_data.nullable)
-        is_true(User.__table__.c.reverse_u_optional_data.nullable)
+        for name, nullable in info:
+            col = User.__table__.c[name]
+            is_(col.type, our_type, name)
+            is_(col.nullable, nullable, name)
 
         is_true(isinstance(User.__table__.c.float_data.type, Float))
+        ne_(User.__table__.c.float_data.type, our_type)
 
-        is_not(User.__table__.c.decimal_data.type, our_type)
+        is_true(isinstance(User.__table__.c.decimal_data.type, Numeric))
+        ne_(User.__table__.c.decimal_data.type, our_type)
 
-        if compat.py310:
-            for suffix in ("", "_fwd"):
-                data_col = User.__table__.c[f"pep604_data{suffix}"]
-                reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
-                optional_col = User.__table__.c[f"pep604_optional{suffix}"]
-                is_(data_col.type, our_type)
-                is_false(data_col.nullable)
-                is_(reverse_col.type, our_type)
-                is_false(reverse_col.nullable)
-                is_(optional_col.type, our_type)
-                is_true(optional_col.nullable)
-
-        if compat.py312:
-            is_(User.__table__.c.pep695_data.type, our_type)
-
-    @testing.variation("union", ["union", "pep604"])
+    @testing.variation(
+        "union",
+        [
+            "union",
+            ("pep604", requires.python310),
+            ("pep695", requires.python312),
+        ],
+    )
     def test_optional_in_annotation_map(self, union):
-        """SQLAlchemy's behaviour is clear: an optional type means the column
-        is inferred as nullable. Some types which a user may want to put in the
-        type annotation map are already optional. JSON is a good example
-        because without any constraint, the type can be None via JSON null or
-        SQL NULL.
-
-        By permitting optional types in the type annotation map, everything
-        just works, and mapped_column(nullable=False) is available if desired.
-
-        See issue #11370
-        """
+        """See issue #11370"""
 
         class Base(DeclarativeBase):
             if union.union:
-                type_annotation_map = {
-                    _Json: JSON,
-                }
+                type_annotation_map = {_Json: JSON}
             elif union.pep604:
-                if not compat.py310:
-                    skip_test("Requires Python 3.10+")
-                type_annotation_map = {
-                    _JsonPep604: JSON,
-                }
+                type_annotation_map = {_JsonPep604: JSON}
+            elif union.pep695:
+                type_annotation_map = {_JsonPep695: JSON}  # noqa: F821
             else:
                 union.fail()
 
@@ -1858,10 +2026,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 json1: Mapped[_Json]
                 json2: Mapped[_Json] = mapped_column(nullable=False)
             elif union.pep604:
-                if not compat.py310:
-                    skip_test("Requires Python 3.10+")
                 json1: Mapped[_JsonPep604]
                 json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+            elif union.pep695:
+                json1: Mapped[_JsonPep695]  # noqa: F821
+                json2: Mapped[_JsonPep695] = mapped_column(  # noqa: F821
+                    nullable=False
+                )
             else:
                 union.fail()
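
For reference, a minimal sketch (not taken from this diff) of the union behavior the reworked test_unions hunks above assert, assuming SQLAlchemy 2.0 with this change applied; the Base and Account names are illustrative:

    from decimal import Decimal
    from typing import Optional, Union

    from sqlalchemy import Numeric
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        # the key is spelled with typing.Union; per the assertions above, the
        # PEP 604 spelling float | Decimal (Python 3.10+) matches the same entry
        type_annotation_map = {Union[float, Decimal]: Numeric(10, 2)}


    class Account(Base):
        __tablename__ = "account"

        id: Mapped[int] = mapped_column(primary_key=True)
        balance: Mapped[Union[float, Decimal]]            # Numeric(10, 2), NOT NULL
        reverse: Mapped[Union[Decimal, float]]            # member order does not matter
        pending: Mapped[Optional[Union[float, Decimal]]]  # Numeric(10, 2), nullable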
 
index 5026e676a76d52643fbdc90f60f2b0ce48bdbf07..f1970f2183bea10e0b56e9ead8d5a115dfc34790 100644 (file)
@@ -87,8 +87,9 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
-from sqlalchemy.testing import skip_test
+from sqlalchemy.testing import requires
 from sqlalchemy.testing import Variation
+from sqlalchemy.testing.assertions import ne_
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
@@ -109,11 +110,6 @@ _StrTypeAlias: TypeAlias = str
 _StrPep695: TypeAlias = str
 _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
 
-_Literal695: TypeAlias = Literal["to-do", "in-progress", "done"]
-_Recursive695_0: TypeAlias = _Literal695
-_Recursive695_1: TypeAlias = _Recursive695_0
-_Recursive695_2: TypeAlias = _Recursive695_1
-
 if compat.py38:
     _TypingLiteral = typing.Literal["a", "b"]
 _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
@@ -138,16 +134,16 @@ type _UnionPep695 = _SomeDict1 | _SomeDict2
 type _StrPep695 = str
 
 type strtypalias_keyword = Annotated[str, mapped_column(info={"hi": "there"})]
-
-strtypalias_tat: typing.TypeAliasType = Annotated[
+type strtypalias_keyword_nested = int | Annotated[
+    str, mapped_column(info={"hi": "there"})]
+strtypalias_ta: typing.TypeAlias = Annotated[
     str, mapped_column(info={"hi": "there"})]
-
 strtypalias_plain = Annotated[str, mapped_column(info={"hi": "there"})]
 
 type _Literal695 = Literal["to-do", "in-progress", "done"]
-type _Recursive695_0 = _Literal695
-type _Recursive695_1 = _Recursive695_0
-type _Recursive695_2 = _Recursive695_1
+type _RecursiveLiteral695 = _Literal695
+
+type _JsonPep695 = _JsonPep604
 """,
         globals(),
     )
@@ -847,6 +843,84 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
+    @testing.variation(
+        "option",
+        [
+            "plain",
+            "union",
+            "union_604",
+            "union_null",
+            "union_null_604",
+            "optional",
+            "optional_union",
+            "optional_union_604",
+        ],
+    )
+    @testing.variation("in_map", ["yes", "no", "value"])
+    @testing.requires.python312
+    def test_pep695_behavior(self, decl_base, in_map, option):
+        """Issue #11955"""
+        # anno only: global tat
+
+        if option.plain:
+            tat = TypeAliasType("tat", str)
+        elif option.union:
+            tat = TypeAliasType("tat", Union[str, int])
+        elif option.union_604:
+            tat = TypeAliasType("tat", str | int)
+        elif option.union_null:
+            tat = TypeAliasType("tat", Union[str, int, None])
+        elif option.union_null_604:
+            tat = TypeAliasType("tat", str | int | None)
+        elif option.optional:
+            tat = TypeAliasType("tat", Optional[str])
+        elif option.optional_union:
+            tat = TypeAliasType("tat", Optional[Union[str, int]])
+        elif option.optional_union_604:
+            tat = TypeAliasType("tat", Optional[str | int])
+        else:
+            option.fail()
+
+        if in_map.yes:
+            decl_base.registry.update_type_annotation_map({tat: String(99)})
+        elif in_map.value:
+            decl_base.registry.update_type_annotation_map(
+                {tat.__value__: String(99)}
+            )
+
+        def declare():
+            class Test(decl_base):
+                __tablename__ = "test"
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[tat]
+
+            return Test.__table__.c.data
+
+        if in_map.yes:
+            col = declare()
+            length = 99
+        elif in_map.value or option.optional or option.plain:
+            with expect_deprecated(
+                "Matching the provided TypeAliasType 'tat' on its "
+                "resolved value without matching it in the "
+                "type_annotation_map is deprecated; add this type to the "
+                "type_annotation_map to allow it to match explicitly.",
+            ):
+                col = declare()
+            length = 99 if in_map.value else None
+        else:
+            with expect_raises_message(
+                exc.ArgumentError,
+                "Could not locate SQLAlchemy Core type for Python type",
+            ):
+                declare()
+            return
+
+        is_true(isinstance(col.type, String))
+        eq_(col.type.length, length)
+        nullable = "null" in option.name or "optional" in option.name
+        eq_(col.nullable, nullable)
+
     @testing.requires.python312
     def test_pep695_typealias_as_typemap_keys(
         self, decl_base: Type[DeclarativeBase]
@@ -867,12 +941,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(Test.__table__.c.data.type.length, 30)
         is_(Test.__table__.c.structure.type._type_affinity, JSON)
 
-    @testing.variation("alias_type", ["none", "typekeyword", "typealiastype"])
+    @testing.variation(
+        "alias_type",
+        ["none", "typekeyword", "typealias", "typekeyword_nested"],
+    )
     @testing.requires.python312
     def test_extract_pep593_from_pep695(
         self, decl_base: Type[DeclarativeBase], alias_type
     ):
         """test #11130"""
+        if alias_type.typekeyword:
+            decl_base.registry.update_type_annotation_map(
+                {strtypalias_keyword: VARCHAR(33)}  # noqa: F821
+            )
+        if alias_type.typekeyword_nested:
+            decl_base.registry.update_type_annotation_map(
+                {strtypalias_keyword_nested: VARCHAR(42)}  # noqa: F821
+            )
 
         class MyClass(decl_base):
             __tablename__ = "my_table"
@@ -881,33 +966,96 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
             if alias_type.typekeyword:
                 data_one: Mapped[strtypalias_keyword]  # noqa: F821
-            elif alias_type.typealiastype:
-                data_one: Mapped[strtypalias_tat]  # noqa: F821
+            elif alias_type.typealias:
+                data_one: Mapped[strtypalias_ta]  # noqa: F821
             elif alias_type.none:
                 data_one: Mapped[strtypalias_plain]  # noqa: F821
+            elif alias_type.typekeyword_nested:
+                data_one: Mapped[strtypalias_keyword_nested]  # noqa: F821
             else:
                 alias_type.fail()
 
         table = MyClass.__table__
         assert table is not None
 
-        eq_(MyClass.data_one.expression.info, {"hi": "there"})
+        if alias_type.typekeyword_nested:
+            # Annotated nested inside a union alias is not extracted
+            eq_(MyClass.data_one.expression.info, {})
+        else:
+            eq_(MyClass.data_one.expression.info, {"hi": "there"})
 
+        if alias_type.typekeyword:
+            eq_(MyClass.data_one.type.length, 33)
+        elif alias_type.typekeyword_nested:
+            eq_(MyClass.data_one.type.length, 42)
+        else:
+            eq_(MyClass.data_one.type.length, None)
+
+    @testing.variation("type_", ["literal", "recursive", "not_literal"])
+    @testing.combinations(True, False, argnames="in_map")
     @testing.requires.python312
-    def test_pep695_literal_defaults_to_enum(self, decl_base):
+    def test_pep695_literal_defaults_to_enum(self, decl_base, type_, in_map):
         """test #11305."""
 
-        class Foo(decl_base):
-            __tablename__ = "footable"
+        def declare():
+            class Foo(decl_base):
+                __tablename__ = "footable"
 
-            id: Mapped[int] = mapped_column(primary_key=True)
-            status: Mapped[_Literal695]
-            r2: Mapped[_Recursive695_2]
+                id: Mapped[int] = mapped_column(primary_key=True)
+                if type_.recursive:
+                    status: Mapped[_RecursiveLiteral695]  # noqa: F821
+                elif type_.literal:
+                    status: Mapped[_Literal695]  # noqa: F821
+                elif type_.not_literal:
+                    status: Mapped[_StrPep695]  # noqa: F821
+                else:
+                    type_.fail()
+
+            return Foo
 
-        for col in (Foo.__table__.c.status, Foo.__table__.c.r2):
+        if in_map:
+            decl_base.registry.update_type_annotation_map(
+                {
+                    _Literal695: Enum(enum.Enum),  # noqa: F821
+                    _RecursiveLiteral695: Enum(enum.Enum),  # noqa: F821
+                    _StrPep695: Enum(enum.Enum),  # noqa: F821
+                }
+            )
+            if type_.recursive:
+                with expect_deprecated(
+                    "Mapping recursive TypeAliasType '.+' that resolve to "
+                    "literal to generate an Enum is deprecated. SQLAlchemy "
+                    "2.1 will not support this use case. Please avoid using "
+                    "recursing TypeAliasType",
+                ):
+                    Foo = declare()
+            elif type_.literal:
+                Foo = declare()
+            else:
+                with expect_raises_message(
+                    exc.ArgumentError,
+                    "Can't associate TypeAliasType '.+' to an Enum "
+                    "since it's not a direct alias of a Literal. Only "
+                    "aliases in this form `type my_alias = Literal.'a', "
+                    "'b'.` are supported when generating Enums.",
+                ):
+                    declare()
+                return
+        else:
+            with expect_deprecated(
+                "Matching the provided TypeAliasType '.*' on its "
+                "resolved value without matching it in the "
+                "type_annotation_map is deprecated; add this type to the "
+                "type_annotation_map to allow it to match explicitly.",
+            ):
+                Foo = declare()
+        col = Foo.__table__.c.status
+        if in_map and not type_.not_literal:
             is_true(isinstance(col.type, Enum))
             eq_(col.type.enums, ["to-do", "in-progress", "done"])
             is_(col.type.native_enum, False)
+        else:
+            is_true(isinstance(col.type, String))
 
     @testing.requires.python38
     def test_typing_literal_identity(self, decl_base):
@@ -1224,6 +1372,33 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         eq_(MyClass.__table__.c.data_four.type.length, 150)
         is_false(MyClass.__table__.c.data_four.nullable)
 
+    def test_newtype_missing_from_map(self, decl_base):
+        # anno only: global str50
+
+        str50 = NewType("str50", str)
+
+        if compat.py310:
+            text = ".*str50"
+        else:
+            # NewTypes before 3.10 had a very bad repr
+            # <function NewType.<locals>.new_type at 0x...>
+            text = ".*NewType.*"
+
+        with expect_deprecated(
+            f"Matching the provided NewType '{text}' on its "
+            "resolved value without matching it in the "
+            "type_annotation_map is deprecated; add this type to the "
+            "type_annotation_map to allow it to match explicitly.",
+        ):
+
+            class MyClass(decl_base):
+                __tablename__ = "my_table"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data_one: Mapped[str50]
+
+        is_true(isinstance(MyClass.data_one.type, String))
+
     def test_extract_base_type_from_pep593(
         self, decl_base: Type[DeclarativeBase]
     ):
@@ -1715,39 +1890,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
-    @testing.variation("union", ["union", "pep604"])
-    @testing.variation("typealias", ["legacy", "pep695"])
-    def test_unions(self, union, typealias):
+    @testing.variation(
+        "union",
+        [
+            "union",
+            ("pep604", requires.python310),
+            "union_null",
+            ("pep604_null", requires.python310),
+        ],
+    )
+    def test_unions(self, union):
+        # anno only: global UnionType
         our_type = Numeric(10, 2)
 
         if union.union:
             UnionType = Union[float, Decimal]
+        elif union.union_null:
+            UnionType = Union[float, Decimal, None]
         elif union.pep604:
-            if not compat.py310:
-                skip_test("Required Python 3.10")
             UnionType = float | Decimal
+        elif union.pep604_null:
+            UnionType = float | Decimal | None
         else:
             union.fail()
 
-        if typealias.legacy:
-            UnionTypeAlias = UnionType
-        elif typealias.pep695:
-            # same as type UnionTypeAlias = UnionType
-            UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
-        else:
-            typealias.fail()
-
         class Base(DeclarativeBase):
-            type_annotation_map = {UnionTypeAlias: our_type}
+            type_annotation_map = {UnionType: our_type}
 
         class User(Base):
             __tablename__ = "users"
-            __table__: Table
 
             id: Mapped[int] = mapped_column(primary_key=True)
 
-            data: Mapped[Union[float, Decimal]] = mapped_column()
-            reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+            data: Mapped[Union[float, Decimal]]
+            reverse_data: Mapped[Union[Decimal, float]]
 
             optional_data: Mapped[Optional[Union[float, Decimal]]] = (
                 mapped_column()
@@ -1764,6 +1940,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 mapped_column()
             )
 
+            refer_union: Mapped[UnionType]
+            refer_union_optional: Mapped[Optional[UnionType]]
+
             float_data: Mapped[float] = mapped_column()
             decimal_data: Mapped[Decimal] = mapped_column()
 
@@ -1779,65 +1958,54 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column()
                 )
 
-            if compat.py312:
-                MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
-                pep695_data: Mapped[MyTypeAlias] = mapped_column()
-
-        is_(User.__table__.c.data.type, our_type)
-        is_false(User.__table__.c.data.nullable)
-        is_(User.__table__.c.reverse_data.type, our_type)
-        is_(User.__table__.c.optional_data.type, our_type)
-        is_true(User.__table__.c.optional_data.nullable)
+        info = [
+            ("data", False),
+            ("reverse_data", False),
+            ("optional_data", True),
+            ("reverse_optional_data", True),
+            ("reverse_u_optional_data", True),
+            ("refer_union", "null" in union.name),
+            ("refer_union_optional", True),
+        ]
+        if compat.py310:
+            info += [
+                ("pep604_data", False),
+                ("pep604_reverse", False),
+                ("pep604_optional", True),
+                ("pep604_data_fwd", False),
+                ("pep604_reverse_fwd", False),
+                ("pep604_optional_fwd", True),
+            ]
 
-        is_(User.__table__.c.reverse_optional_data.type, our_type)
-        is_(User.__table__.c.reverse_u_optional_data.type, our_type)
-        is_true(User.__table__.c.reverse_optional_data.nullable)
-        is_true(User.__table__.c.reverse_u_optional_data.nullable)
+        for name, nullable in info:
+            col = User.__table__.c[name]
+            is_(col.type, our_type, name)
+            is_(col.nullable, nullable, name)
 
         is_true(isinstance(User.__table__.c.float_data.type, Float))
+        ne_(User.__table__.c.float_data.type, our_type)
 
-        is_not(User.__table__.c.decimal_data.type, our_type)
+        is_true(isinstance(User.__table__.c.decimal_data.type, Numeric))
+        ne_(User.__table__.c.decimal_data.type, our_type)
 
-        if compat.py310:
-            for suffix in ("", "_fwd"):
-                data_col = User.__table__.c[f"pep604_data{suffix}"]
-                reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
-                optional_col = User.__table__.c[f"pep604_optional{suffix}"]
-                is_(data_col.type, our_type)
-                is_false(data_col.nullable)
-                is_(reverse_col.type, our_type)
-                is_false(reverse_col.nullable)
-                is_(optional_col.type, our_type)
-                is_true(optional_col.nullable)
-
-        if compat.py312:
-            is_(User.__table__.c.pep695_data.type, our_type)
-
-    @testing.variation("union", ["union", "pep604"])
+    @testing.variation(
+        "union",
+        [
+            "union",
+            ("pep604", requires.python310),
+            ("pep695", requires.python312),
+        ],
+    )
     def test_optional_in_annotation_map(self, union):
-        """SQLAlchemy's behaviour is clear: an optional type means the column
-        is inferred as nullable. Some types which a user may want to put in the
-        type annotation map are already optional. JSON is a good example
-        because without any constraint, the type can be None via JSON null or
-        SQL NULL.
-
-        By permitting optional types in the type annotation map, everything
-        just works, and mapped_column(nullable=False) is available if desired.
-
-        See issue #11370
-        """
+        """See issue #11370"""
 
         class Base(DeclarativeBase):
             if union.union:
-                type_annotation_map = {
-                    _Json: JSON,
-                }
+                type_annotation_map = {_Json: JSON}
             elif union.pep604:
-                if not compat.py310:
-                    skip_test("Requires Python 3.10+")
-                type_annotation_map = {
-                    _JsonPep604: JSON,
-                }
+                type_annotation_map = {_JsonPep604: JSON}
+            elif union.pep695:
+                type_annotation_map = {_JsonPep695: JSON}  # noqa: F821
             else:
                 union.fail()
 
@@ -1849,10 +2017,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 json1: Mapped[_Json]
                 json2: Mapped[_Json] = mapped_column(nullable=False)
             elif union.pep604:
-                if not compat.py310:
-                    skip_test("Requires Python 3.10+")
                 json1: Mapped[_JsonPep604]
                 json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+            elif union.pep695:
+                json1: Mapped[_JsonPep695]  # noqa: F821
+                json2: Mapped[_JsonPep695] = mapped_column(  # noqa: F821
+                    nullable=False
+                )
             else:
                 union.fail()
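
As a companion to test_pep695_behavior above, a minimal sketch (not taken from this diff, Python 3.12 only) of the registration requirement that test asserts: a PEP 695 TypeAliasType is looked up in the type map by the alias object itself, and relying on its resolved __value__ alone now emits the deprecation warning quoted in the test. Str50 and Thing are illustrative names:

    from sqlalchemy import String
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    type Str50 = str  # PEP 695 alias


    class Base(DeclarativeBase):
        # the alias object itself is the key; omitting this entry triggers the
        # deprecation warning asserted in test_pep695_behavior
        type_annotation_map = {Str50: String(50)}


    class Thing(Base):
        __tablename__ = "thing"

        id: Mapped[int] = mapped_column(primary_key=True)
        data: Mapped[Str50]  # resolves to String(50), NOT NULL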
 
index 3a06ac9f2735b48d047a18c69b0a7c9c6b51af74..a3b6965c862a4a25137c4f604db2acb03192b625 100644 (file)
@@ -12,6 +12,7 @@ that it extracts from the documentation.
 from argparse import ArgumentParser
 from argparse import RawDescriptionHelpFormatter
 from collections.abc import Iterator
+import dataclasses
 from functools import partial
 from itertools import chain
 from pathlib import Path
@@ -33,6 +34,8 @@ ignore_paths = (
     re.compile(r"build"),
 )
 
+CUSTOM_TARGET_VERSIONS = {"declarative_tables.rst": "PY312"}
+
 
 class BlockLine(NamedTuple):
     line: str
@@ -66,6 +69,12 @@ def _format_block(
         code = "\n".join(l.code for l in input_block)
 
     mode = PYTHON_BLACK_MODE if is_python_file else RST_BLACK_MODE
+    custom_target = CUSTOM_TARGET_VERSIONS.get(Path(file).name)
+    if custom_target:
+        mode = dataclasses.replace(
+            mode, target_versions={TargetVersion[custom_target]}
+        )
+
     try:
         formatted = format_str(code, mode=mode)
     except Exception as e: