From: Frazer McLean Date: Thu, 3 Oct 2024 22:21:12 +0000 (-0400) Subject: dont match partial types in type_annotation_map X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=40c30ec44616223216737327f97bac66a13eedee;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git dont match partial types in type_annotation_map Fixed issue regarding ``Union`` types that would be present in the :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry` or declarative base class, where a ``Mapped[]`` element that included one of the subtypes present in that ``Union`` would be matched to that entry, potentially ignoring other entries that matched exactly. The correct behavior now takes place such that an entry should only match in ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained type. For example, an attribute with ``Mapped[float]`` would previously match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this will no longer match and will now only match to an entry that states ``float``. Pull request courtesy Frazer McLean. Fixes #11370 Closes: #11942 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11942 Pull-request-sha: 21a3d1971a04e117a557f6e6bac77bce9f6bb0a9 Change-Id: I3467be00f8fa8bd011dd4805a77a3b80ff74a215 --- diff --git a/doc/build/changelog/unreleased_20/11370.rst b/doc/build/changelog/unreleased_20/11370.rst new file mode 100644 index 0000000000..56e85531fc --- /dev/null +++ b/doc/build/changelog/unreleased_20/11370.rst @@ -0,0 +1,15 @@ +.. change:: + :tags: bug, orm + :tickets: 11370 + + Fixed issue regarding ``Union`` types that would be present in the + :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry` + or declarative base class, where a ``Mapped[]`` element that included one + of the subtypes present in that ``Union`` would be matched to that entry, + potentially ignoring other entries that matched exactly. The correct + behavior now takes place such that an entry should only match in + ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained + type. For example, an attribute with ``Mapped[float]`` would previously + match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this + will no longer match and will now only match to an entry that states + ``float``. Pull request courtesy Frazer McLean. diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 71270c6b4e..6ad3176195 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -73,6 +73,7 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing from ..util.typing import CallableReference +from ..util.typing import de_optionalize_union_types from ..util.typing import flatten_newtype from ..util.typing import is_generic from ..util.typing import is_literal @@ -1225,11 +1226,8 @@ class registry: self.type_annotation_map.update( { - sub_type: sqltype + de_optionalize_union_types(typ): sqltype for typ, sqltype in type_annotation_map.items() - for sub_type in compat_typing.expand_unions( - typ, include_union=True, discard_none=True - ) } ) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 7510e7a387..be2f101352 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -421,6 +421,9 @@ def de_optionalize_union_types( """ + while is_pep695(type_): + type_ = type_.__value__ + if is_fwd_ref(type_): return de_optionalize_fwd_ref_union_types(type_) @@ -477,26 +480,6 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: return cast(Any, Union).__getitem__(types) # type: ignore -def expand_unions( - type_: Type[Any], include_union: bool = False, discard_none: bool = False -) -> Tuple[Type[Any], ...]: - """Return a type as a tuple of individual types, expanding for - ``Union`` types.""" - - if is_union(type_): - typ = set(type_.__args__) - - if discard_none: - typ.discard(NoneType) - - if include_union: - return (type_,) + tuple(typ) # type: ignore - else: - return tuple(typ) # type: ignore - else: - return (type_,) - - def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of( type_, @@ -511,7 +494,7 @@ def is_optional_union(type_: Any) -> bool: def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: - return is_origin_of(type_, "Union") + return is_origin_of(type_, "Union", "UnionType") def is_origin_of_cls( diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 77ab9ff222..0f074e937c 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -4,6 +4,9 @@ import inspect from pathlib import Path import pickle import sys +import typing + +import typing_extensions from sqlalchemy import exc from sqlalchemy import sql @@ -38,6 +41,7 @@ from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence from sqlalchemy.util._collections import merge_lists_w_ordering from sqlalchemy.util._has_cython import _all_cython_modules +from sqlalchemy.util.typing import is_union class WeakSequenceTest(fixtures.TestBase): @@ -3657,3 +3661,11 @@ class CyExtensionTest(fixtures.TestBase): print(expected) print(setup_modules) eq_(setup_modules, expected) + + +class TypingTest(fixtures.TestBase): + def test_is_union(self): + assert is_union(typing.Union[str, int]) + assert is_union(typing_extensions.Union[str, int]) + if compat.py310: + assert is_union(str | int) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index ca2e01242f..2aad4dc330 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -34,6 +34,7 @@ import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType from sqlalchemy import BIGINT from sqlalchemy import BigInteger @@ -41,6 +42,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -94,6 +96,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import skip_test from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat @@ -123,6 +126,19 @@ _Recursive695_2: TypeAlias = _Recursive695_1 _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + if compat.py312: exec( """ @@ -1706,11 +1722,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation("union", ["union", "pep604"]) + @testing.variation("typealias", ["legacy", "pep695"]) + def test_unions(self, union, typealias): our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.pep604: + if not compat.py310: + skip_test("Required Python 3.10") + UnionType = float | Decimal + else: + union.fail() + + if typealias.legacy: + UnionTypeAlias = UnionType + elif typealias.pep695: + # same as type UnionTypeAlias = UnionType + UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) + else: + typealias.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionTypeAlias: our_type} class User(Base): __tablename__ = "users" @@ -1751,6 +1786,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + if compat.py312: + MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) + pep695_data: Mapped[MyTypeAlias] = mapped_column() + is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) is_(User.__table__.c.reverse_data.type, our_type) @@ -1762,8 +1801,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.reverse_optional_data.nullable) is_true(User.__table__.c.reverse_u_optional_data.nullable) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + is_true(isinstance(User.__table__.c.float_data.type, Numeric)) + is_not(User.__table__.c.decimal_data.type, our_type) if compat.py310: for suffix in ("", "_fwd"): @@ -1777,6 +1817,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(optional_col.type, our_type) is_true(optional_col.nullable) + if compat.py312: + is_(User.__table__.c.pep695_data.type, our_type) + + @testing.variation("union", ["union", "pep604"]) + def test_optional_in_annotation_map(self, union): + """SQLAlchemy's behaviour is clear: an optional type means the column + is inferred as nullable. Some types which a user may want to put in the + type annotation map are already optional. JSON is a good example + because without any constraint, the type can be None via JSON null or + SQL NULL. + + By permitting optional types in the type annotation map, everything + just works, and mapped_column(nullable=False) is available if desired. + + See issue #11370 + """ + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = { + _Json: JSON, + } + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + type_annotation_map = { + _JsonPep604: JSON, + } + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = mapped_column(nullable=False) + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + @testing.combinations( ("not_optional",), ("optional",), diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 6d48769264..d5a5c18c3e 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -25,6 +25,7 @@ import typing_extensions from typing_extensions import get_args as get_args from typing_extensions import Literal as Literal from typing_extensions import TypeAlias as TypeAlias +from typing_extensions import TypeAliasType from sqlalchemy import BIGINT from sqlalchemy import BigInteger @@ -32,6 +33,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exc as sa_exc +from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Identity @@ -85,6 +87,7 @@ from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true +from sqlalchemy.testing import skip_test from sqlalchemy.testing import Variation from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.util import compat @@ -114,6 +117,19 @@ _Recursive695_2: TypeAlias = _Recursive695_1 _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] +_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None] +_JsonObject: TypeAlias = Dict[str, "_Json"] +_JsonArray: TypeAlias = List["_Json"] +_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive] + +if compat.py310: + _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None + _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"] + _JsonArrayPep604: TypeAlias = list["_JsonPep604"] + _JsonPep604: TypeAlias = ( + _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604 + ) + if compat.py312: exec( """ @@ -1697,11 +1713,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): else: is_(getattr(Element.__table__.c.data, paramname), override_value) - def test_unions(self): + @testing.variation("union", ["union", "pep604"]) + @testing.variation("typealias", ["legacy", "pep695"]) + def test_unions(self, union, typealias): our_type = Numeric(10, 2) + if union.union: + UnionType = Union[float, Decimal] + elif union.pep604: + if not compat.py310: + skip_test("Required Python 3.10") + UnionType = float | Decimal + else: + union.fail() + + if typealias.legacy: + UnionTypeAlias = UnionType + elif typealias.pep695: + # same as type UnionTypeAlias = UnionType + UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType) + else: + typealias.fail() + class Base(DeclarativeBase): - type_annotation_map = {Union[float, Decimal]: our_type} + type_annotation_map = {UnionTypeAlias: our_type} class User(Base): __tablename__ = "users" @@ -1742,6 +1777,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column() ) + if compat.py312: + MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal) + pep695_data: Mapped[MyTypeAlias] = mapped_column() + is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) is_(User.__table__.c.reverse_data.type, our_type) @@ -1753,8 +1792,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.reverse_optional_data.nullable) is_true(User.__table__.c.reverse_u_optional_data.nullable) - is_(User.__table__.c.float_data.type, our_type) - is_(User.__table__.c.decimal_data.type, our_type) + is_true(isinstance(User.__table__.c.float_data.type, Float)) + is_true(isinstance(User.__table__.c.float_data.type, Numeric)) + is_not(User.__table__.c.decimal_data.type, our_type) if compat.py310: for suffix in ("", "_fwd"): @@ -1768,6 +1808,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(optional_col.type, our_type) is_true(optional_col.nullable) + if compat.py312: + is_(User.__table__.c.pep695_data.type, our_type) + + @testing.variation("union", ["union", "pep604"]) + def test_optional_in_annotation_map(self, union): + """SQLAlchemy's behaviour is clear: an optional type means the column + is inferred as nullable. Some types which a user may want to put in the + type annotation map are already optional. JSON is a good example + because without any constraint, the type can be None via JSON null or + SQL NULL. + + By permitting optional types in the type annotation map, everything + just works, and mapped_column(nullable=False) is available if desired. + + See issue #11370 + """ + + class Base(DeclarativeBase): + if union.union: + type_annotation_map = { + _Json: JSON, + } + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + type_annotation_map = { + _JsonPep604: JSON, + } + else: + union.fail() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + if union.union: + json1: Mapped[_Json] + json2: Mapped[_Json] = mapped_column(nullable=False) + elif union.pep604: + if not compat.py310: + skip_test("Requires Python 3.10+") + json1: Mapped[_JsonPep604] + json2: Mapped[_JsonPep604] = mapped_column(nullable=False) + else: + union.fail() + + is_(A.__table__.c.json1.type._type_affinity, JSON) + is_(A.__table__.c.json2.type._type_affinity, JSON) + is_true(A.__table__.c.json1.nullable) + is_false(A.__table__.c.json2.nullable) + @testing.combinations( ("not_optional",), ("optional",),