]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
additional de-stringify pass for unions
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Sep 2022 17:19:08 +0000 (13:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Sep 2022 18:15:50 +0000 (14:15 -0400)
the change in c3cfee5b00a40790c18d took out
a pass for de-stringify that broke some un-tested cases
for Optional with future annotations mode.   Adding tests
for this revealed that this was a subset of
a more general case where Union is presented
with ForwardRefs inside of it matching up within the type
map, which wasn't working before either, fixed that as well with
an additional de-stringify for elements within the Union.

Fixes: #8478
Change-Id: I8804cf6c67f14d10804584e1cddd2cfaa2376654

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations.py

index 7d71756780662eef61495372c0ef0c5ae61ce737..3d9fe578d8b73a5030b3d868e3dd043d647d2cdd 100644 (file)
@@ -51,9 +51,11 @@ from ..sql.schema import Column
 from ..sql.schema import SchemaConst
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import de_stringify_annotation
+from ..util.typing import de_stringify_union_elements
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_optional_union
 from ..util.typing import is_pep593
+from ..util.typing import is_union
 from ..util.typing import Self
 from ..util.typing import typing_get_args
 
@@ -655,6 +657,9 @@ class MappedColumn(
         if is_fwd_ref(argument):
             argument = de_stringify_annotation(cls, argument)
 
+        if is_union(argument):
+            argument = de_stringify_union_elements(cls, argument)
+
         nullable = is_optional_union(argument)
 
         if not self._has_nullable:
@@ -690,6 +695,7 @@ class MappedColumn(
                 checks = (our_type,)
 
             for check_type in checks:
+
                 if registry.type_annotation_map:
                     new_sqltype = registry.type_annotation_map.get(check_type)
                 if new_sqltype is None:
index 85c1bae72bc3f6d0dcfaeff8e1c19275af4d30fd..a0d59a6305b25f761ecc66b81da45ea8230a5414 100644 (file)
@@ -120,6 +120,19 @@ def de_stringify_annotation(
     return annotation  # type: ignore
 
 
+def de_stringify_union_elements(
+    cls: Type[Any],
+    annotation: _AnnotationScanType,
+    str_cleanup_fn: Optional[Callable[[str], str]] = None,
+) -> Type[Any]:
+    return make_union_type(
+        *[
+            de_stringify_annotation(cls, anno, str_cleanup_fn)
+            for anno in annotation.__args__  # type: ignore
+        ]
+    )
+
+
 def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
     return type_ is not None and typing_get_origin(type_) is Annotated
 
@@ -186,7 +199,7 @@ def expand_unions(
         return (type_,)
 
 
-def is_optional(type_):
+def is_optional(type_: Any) -> bool:
     return is_origin_of(
         type_,
         "Optional",
@@ -199,7 +212,7 @@ def is_optional_union(type_: Any) -> bool:
     return is_optional(type_) and NoneType in typing_get_args(type_)
 
 
-def is_union(type_):
+def is_union(type_: Any) -> bool:
     return is_origin_of(type_, "Union")
 
 
index 74cbebb4da7221d85e2ce22cb67949ff27a3e61a..76ee464fad04b676f7a7bea5f5e11d8479967828 100644 (file)
@@ -1,13 +1,19 @@
 from __future__ import annotations
 
+from decimal import Decimal
 from typing import List
+from typing import Optional
 from typing import Set
 from typing import TypeVar
+from typing import Union
 
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import Numeric
+from sqlalchemy import Table
 from sqlalchemy.orm import attribute_mapped_collection
+from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedCollection
@@ -16,7 +22,8 @@ from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
-from .test_typed_mapping import MappedColumnTest  # noqa
+from sqlalchemy.util import compat
+from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
 from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
 
 """runs the annotation-sensitive tests from test_typed_mappings while
@@ -28,6 +35,79 @@ having ``from __future__ import annotations`` in effect.
 _R = TypeVar("_R")
 
 
+class MappedColumnTest(_MappedColumnTest):
+    def test_unions(self):
+        our_type = Numeric(10, 2)
+
+        class Base(DeclarativeBase):
+            type_annotation_map = {Union[float, Decimal]: our_type}
+
+        class User(Base):
+            __tablename__ = "users"
+            __table__: Table
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+            data: Mapped[Union[float, Decimal]] = mapped_column()
+            reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
+            optional_data: Mapped[
+                Optional[Union[float, Decimal]]
+            ] = mapped_column()
+
+            # use Optional directly
+            reverse_optional_data: Mapped[
+                Optional[Union[Decimal, float]]
+            ] = mapped_column()
+
+            # use Union with None, same as Optional but presents differently
+            # (Optional object with __origin__ Union vs. Union)
+            reverse_u_optional_data: Mapped[
+                Union[Decimal, float, None]
+            ] = mapped_column()
+
+            float_data: Mapped[float] = mapped_column()
+            decimal_data: Mapped[Decimal] = mapped_column()
+
+            if compat.py310:
+                pep604_data: Mapped[float | Decimal] = mapped_column()
+                pep604_reverse: Mapped[Decimal | float] = mapped_column()
+                pep604_optional: Mapped[
+                    Decimal | float | None
+                ] = mapped_column()
+                pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+                pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+                pep604_optional_fwd: Mapped[
+                    "Decimal | float | None"
+                ] = mapped_column()
+
+        is_(User.__table__.c.data.type, our_type)
+        is_false(User.__table__.c.data.nullable)
+        is_(User.__table__.c.reverse_data.type, our_type)
+        is_(User.__table__.c.optional_data.type, our_type)
+        is_true(User.__table__.c.optional_data.nullable)
+
+        is_(User.__table__.c.reverse_optional_data.type, our_type)
+        is_(User.__table__.c.reverse_u_optional_data.type, our_type)
+        is_true(User.__table__.c.reverse_optional_data.nullable)
+        is_true(User.__table__.c.reverse_u_optional_data.nullable)
+
+        is_(User.__table__.c.float_data.type, our_type)
+        is_(User.__table__.c.decimal_data.type, our_type)
+
+        if compat.py310:
+            for suffix in ("", "_fwd"):
+                data_col = User.__table__.c[f"pep604_data{suffix}"]
+                reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+                optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+                is_(data_col.type, our_type)
+                is_false(data_col.nullable)
+                is_(reverse_col.type, our_type)
+                is_false(reverse_col.nullable)
+                is_(optional_col.type, our_type)
+                is_true(optional_col.nullable)
+
+
 class MappedOneArg(MappedCollection[str, _R]):
     pass