From 8a45f13ef66628c5d5ff30bed30c3a62874f041e Mon Sep 17 00:00:00 2001 From: Frazer McLean Date: Thu, 3 Oct 2024 18:21:12 -0400 Subject: [PATCH] dont match partial types in type_annotation_map Fixed issue regarding ``Union`` types that would be present in the :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry` or declarative base class, where a ``Mapped[]`` element that included one of the subtypes present in that ``Union`` would be matched to that entry, potentially ignoring other entries that matched exactly. The correct behavior now takes place such that an entry should only match in ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained type. For example, an attribute with ``Mapped[float]`` would previously match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this will no longer match and will now only match to an entry that states ``float``. Pull request courtesy Frazer McLean. Fixes #11370 Closes: #11942 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11942 Pull-request-sha: 21a3d1971a04e117a557f6e6bac77bce9f6bb0a9 Change-Id: I3467be00f8fa8bd011dd4805a77a3b80ff74a215 (cherry picked from commit 40c30ec44616223216737327f97bac66a13eedee) --- doc/build/changelog/unreleased_20/11370.rst | 15 +++ lib/sqlalchemy/orm/decl_api.py | 6 +- lib/sqlalchemy/util/typing.py | 25 +---- test/base/test_utils.py | 12 +++ .../test_tm_future_annotations_sync.py | 99 ++++++++++++++++++- test/orm/declarative/test_typed_mapping.py | 99 ++++++++++++++++++- 6 files changed, 223 insertions(+), 33 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/11370.rst diff --git a/doc/build/changelog/unreleased_20/11370.rst b/doc/build/changelog/unreleased_20/11370.rst new file mode 100644 index 0000000000..56e85531fc --- /dev/null +++ b/doc/build/changelog/unreleased_20/11370.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: bug, orm + :tickets: 11370 + + Fixed issue regarding ``Union`` types that would be present in the + :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry` + or declarative base class, where a ``Mapped[]`` element that included one + of the subtypes present in that ``Union`` would be matched to that entry, + potentially ignoring other entries that matched exactly. The correct + behavior now takes place such that an entry should only match in + ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained + type. For example, an attribute with ``Mapped[float]`` would previously + match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this + will no longer match and will now only match to an entry that states + ``float``. Pull request courtesy Frazer McLean. diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 311a9bd4a5..718cf72516 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -73,6 +73,7 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing from ..util.typing import CallableReference +from ..util.typing import de_optionalize_union_types from ..util.typing import flatten_newtype from ..util.typing import is_generic from ..util.typing import is_literal @@ -1225,11 +1226,8 @@ class registry: self.type_annotation_map.update( { - sub_type: sqltype + de_optionalize_union_types(typ): sqltype for typ, sqltype in type_annotation_map.items() - for sub_type in compat_typing.expand_unions( - typ, include_union=True, discard_none=True - ) } ) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index a3df977705..bd1ebd4c01 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -422,6 +422,9 @@ def de_optionalize_union_types( """ + while is_pep695(type_): + type_ = type_.__value__ + if is_fwd_ref(type_): return de_optionalize_fwd_ref_union_types(type_) @@ -478,26 +481,6 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: return cast(Any, Union).__getitem__(types) # type: ignore -def expand_unions( - type_: Type[Any], include_union: bool = False, discard_none: bool = False -) -> Tuple[Type[Any], ...]: - """Return a type as a tuple of individual types, expanding for - ``Union`` types.""" - - if is_union(type_): - typ = set(type_.__args__) - - if discard_none: - typ.discard(NoneType) - - if include_union: - return (type_,) + tuple(typ) # type: ignore - else: - return tuple(typ) # type: ignore - else: - return (type_,) - - def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of( type_, @@ -512,7 +495,7 @@ def is_optional_union(type_: Any) -> bool: def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: - return is_origin_of(type_, "Union") + return is_origin_of(type_, "Union", "UnionType") def is_origin_of_cls( diff --git a/test/base/test_utils.py b/test/base/test_utils.py index de8712c852..85c419e94e 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -4,6 +4,9 @@ import inspect from pathlib import Path import pickle import sys +import typing + +import typing_extensions from sqlalchemy import exc from sqlalchemy import sql @@ -39,6 +42,7 @@ from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering from sqlalchemy.util._has_cy import _import_cy_extensions from sqlalchemy.util._has_cy import HAS_CYEXTENSION +from sqlalchemy.util.typing import is_union class WeakSequenceTest(fixtures.TestBase): @@ -3630,3 +3634,11 @@ class CyExtensionTest(fixtures.TestBase): for f in cython_files } eq_({m.__name__ for m in ext}, set(names)) + + +class TypingTest(fixtures.TestBase): + def test_is_union(self): + assert is_union(typing.Union[str, int]) + assert is_union(typing_extensions.Union[str, int]) + if compat.py310: + assert is_union(str | int) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 2aa8f0f0b0..6bf7d02c56 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -33,6 +33,7 @@ import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType from typing_extensions import TypedDict from sqlalchemy import BIGINT @@ -41,6 +42,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -94,6 +96,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import skip_test from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat @@ -124,6 +127,19 @@ if compat.py38: _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + if compat.py312: exec( """ @@ -1708,11 +1724,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation("union", ["union", "pep604"]) + @testing.variation("typealias", ["legacy", "pep695"]) + def test_unions(self, union, typealias): our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.pep604: + if not compat.py310: + skip_test("Required Python 3.10") + UnionType = float | Decimal + else: + union.fail() + + if typealias.legacy: + UnionTypeAlias = UnionType + elif typealias.pep695: + # same as type UnionTypeAlias = UnionType + UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) + else: + typealias.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionTypeAlias: our_type} class User(Base): __tablename__ = "users" @@ -1753,6 +1788,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + if compat.py312: + MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) + pep695_data: Mapped[MyTypeAlias] = mapped_column() + is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) is_(User.__table__.c.reverse_data.type, our_type) @@ -1764,8 +1803,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.reverse_optional_data.nullable) is_true(User.__table__.c.reverse_u_optional_data.nullable) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + is_true(isinstance(User.__table__.c.float_data.type, Numeric)) + is_not(User.__table__.c.decimal_data.type, our_type) if compat.py310: for suffix in ("", "_fwd"): @@ -1779,6 +1819,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(optional_col.type, our_type) is_true(optional_col.nullable) + if compat.py312: + is_(User.__table__.c.pep695_data.type, our_type) + + @testing.variation("union", ["union", "pep604"]) + def test_optional_in_annotation_map(self, union): + """SQLAlchemy's behaviour is clear: an optional type means the column + is inferred as nullable. Some types which a user may want to put in the + type annotation map are already optional. JSON is a good example + because without any constraint, the type can be None via JSON null or + SQL NULL. + + By permitting optional types in the type annotation map, everything + just works, and mapped_column(nullable=False) is available if desired. + + See issue #11370 + """ + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = { + _Json: JSON, + } + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + type_annotation_map = { + _JsonPep604: JSON, + } + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = mapped_column(nullable=False) + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + @testing.combinations( ("not_optional",), ("optional",), diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index b50573fa12..929041ccfb 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -24,6 +24,7 @@ import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType from typing_extensions import TypedDict from sqlalchemy import BIGINT @@ -32,6 +33,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -85,6 +87,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import skip_test from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat @@ -115,6 +118,19 @@ if compat.py38: _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + if compat.py312: exec( """ @@ -1699,11 +1715,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation("union", ["union", "pep604"]) + @testing.variation("typealias", ["legacy", "pep695"]) + def test_unions(self, union, typealias): our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.pep604: + if not compat.py310: + skip_test("Required Python 3.10") + UnionType = float | Decimal + else: + union.fail() + + if typealias.legacy: + UnionTypeAlias = UnionType + elif typealias.pep695: + # same as type UnionTypeAlias = UnionType + UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) + else: + typealias.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionTypeAlias: our_type} class User(Base): __tablename__ = "users" @@ -1744,6 +1779,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + if compat.py312: + MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) + pep695_data: Mapped[MyTypeAlias] = mapped_column() + is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) is_(User.__table__.c.reverse_data.type, our_type) @@ -1755,8 +1794,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.reverse_optional_data.nullable) is_true(User.__table__.c.reverse_u_optional_data.nullable) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + is_true(isinstance(User.__table__.c.float_data.type, Numeric)) + is_not(User.__table__.c.decimal_data.type, our_type) if compat.py310: for suffix in ("", "_fwd"): @@ -1770,6 +1810,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(optional_col.type, our_type) is_true(optional_col.nullable) + if compat.py312: + is_(User.__table__.c.pep695_data.type, our_type) + + @testing.variation("union", ["union", "pep604"]) + def test_optional_in_annotation_map(self, union): + """SQLAlchemy's behaviour is clear: an optional type means the column + is inferred as nullable. Some types which a user may want to put in the + type annotation map are already optional. JSON is a good example + because without any constraint, the type can be None via JSON null or + SQL NULL. + + By permitting optional types in the type annotation map, everything + just works, and mapped_column(nullable=False) is available if desired. + + See issue #11370 + """ + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = { + _Json: JSON, + } + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + type_annotation_map = { + _JsonPep604: JSON, + } + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = mapped_column(nullable=False) + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + @testing.combinations( ("not_optional",), ("optional",), -- 2.47.3