From fcd298e1afe9b309de34d28b35e4debc3940d6b9 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 8 Sep 2022 13:19:08 -0400 Subject: [PATCH] additional de-stringify pass for unions the change in c3cfee5b00a40790c18d took out a pass for de-stringify that broke some un-tested cases for Optional with future annotations mode. Adding tests for this revealed that this was a subset of a more general case where Union is presented with ForwardRefs inside of it matching up within the type map, which wasn't working before either, fixed that as well with an additional de-stringify for elements within the Union. Fixes: #8478 Change-Id: I8804cf6c67f14d10804584e1cddd2cfaa2376654 --- lib/sqlalchemy/orm/properties.py | 6 ++ lib/sqlalchemy/util/typing.py | 17 +++- .../declarative/test_tm_future_annotations.py | 82 ++++++++++++++++++- 3 files changed, 102 insertions(+), 3 deletions(-) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 7d71756780..3d9fe578d8 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -51,9 +51,11 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation +from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 +from ..util.typing import is_union from ..util.typing import Self from ..util.typing import typing_get_args @@ -655,6 +657,9 @@ class MappedColumn( if is_fwd_ref(argument): argument = de_stringify_annotation(cls, argument) + if is_union(argument): + argument = de_stringify_union_elements(cls, argument) + nullable = is_optional_union(argument) if not self._has_nullable: @@ -690,6 +695,7 @@ class MappedColumn( checks = (our_type,) for check_type in checks: + if registry.type_annotation_map: new_sqltype = registry.type_annotation_map.get(check_type) if new_sqltype is None: diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 85c1bae72b..a0d59a6305 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -120,6 +120,19 @@ def de_stringify_annotation( return annotation # type: ignore +def de_stringify_union_elements( + cls: Type[Any], + annotation: _AnnotationScanType, + str_cleanup_fn: Optional[Callable[[str], str]] = None, +) -> Type[Any]: + return make_union_type( + *[ + de_stringify_annotation(cls, anno, str_cleanup_fn) + for anno in annotation.__args__ # type: ignore + ] + ) + + def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated @@ -186,7 +199,7 @@ def expand_unions( return (type_,) -def is_optional(type_): +def is_optional(type_: Any) -> bool: return is_origin_of( type_, "Optional", @@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool: return is_optional(type_) and NoneType in typing_get_args(type_) -def is_union(type_): +def is_union(type_: Any) -> bool: return is_origin_of(type_, "Union") diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index 74cbebb4da..76ee464fad 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,13 +1,19 @@ from __future__ import annotations +from decimal import Decimal from typing import List +from typing import Optional from typing import Set from typing import TypeVar +from typing import Union from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import Numeric +from sqlalchemy import Table from sqlalchemy.orm import attribute_mapped_collection +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import MappedCollection @@ -16,7 +22,8 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true -from .test_typed_mapping import MappedColumnTest # noqa +from sqlalchemy.util import compat +from .test_typed_mapping import MappedColumnTest as _MappedColumnTest from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest """runs the annotation-sensitive tests from test_typed_mappings while @@ -28,6 +35,79 @@ having ``from __future__ import annotations`` in effect. _R = TypeVar("_R") +class MappedColumnTest(_MappedColumnTest): + def test_unions(self): + our_type = Numeric(10, 2) + + class Base(DeclarativeBase): + type_annotation_map = {Union[float, Decimal]: our_type} + + class User(Base): + __tablename__ = "users" + __table__: Table + + id: Mapped[int] = mapped_column(primary_key=True) + + data: Mapped[Union[float, Decimal]] = mapped_column() + reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + + optional_data: Mapped[ + Optional[Union[float, Decimal]] + ] = mapped_column() + + # use Optional directly + reverse_optional_data: Mapped[ + Optional[Union[Decimal, float]] + ] = mapped_column() + + # use Union with None, same as Optional but presents differently + # (Optional object with __origin__ Union vs. Union) + reverse_u_optional_data: Mapped[ + Union[Decimal, float, None] + ] = mapped_column() + + float_data: Mapped[float] = mapped_column() + decimal_data: Mapped[Decimal] = mapped_column() + + if compat.py310: + pep604_data: Mapped[float | Decimal] = mapped_column() + pep604_reverse: Mapped[Decimal | float] = mapped_column() + pep604_optional: Mapped[ + Decimal | float | None + ] = mapped_column() + pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() + pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() + pep604_optional_fwd: Mapped[ + "Decimal | float | None" + ] = mapped_column() + + is_(User.__table__.c.data.type, our_type) + is_false(User.__table__.c.data.nullable) + is_(User.__table__.c.reverse_data.type, our_type) + is_(User.__table__.c.optional_data.type, our_type) + is_true(User.__table__.c.optional_data.nullable) + + is_(User.__table__.c.reverse_optional_data.type, our_type) + is_(User.__table__.c.reverse_u_optional_data.type, our_type) + is_true(User.__table__.c.reverse_optional_data.nullable) + is_true(User.__table__.c.reverse_u_optional_data.nullable) + + is_(User.__table__.c.float_data.type, our_type) + is_(User.__table__.c.decimal_data.type, our_type) + + if compat.py310: + for suffix in ("", "_fwd"): + data_col = User.__table__.c[f"pep604_data{suffix}"] + reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] + optional_col = User.__table__.c[f"pep604_optional{suffix}"] + is_(data_col.type, our_type) + is_false(data_col.nullable) + is_(reverse_col.type, our_type) + is_false(reverse_col.nullable) + is_(optional_col.type, our_type) + is_true(optional_col.nullable) + + class MappedOneArg(MappedCollection[str, _R]): pass -- 2.47.2