]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Detection of PEP 604 union syntax.
authorPeter Schutt <peter.github@proton.me>
Thu, 1 Sep 2022 23:11:40 +0000 (19:11 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Thu, 1 Sep 2022 23:11:40 +0000 (19:11 -0400)
### Description

Fixes #8478

Handle `UnionType` as arguments to `Mapped`, e.g., `Mapped[str | None]`:

- adds `utils.typing.is_optional_union()` used to detect if a column should be nullable.
- adds `"UnionType"` to `utils.typing.is_optional()` names.
- uses `get_origin()` in `utils.typing.is_origin_of()` as `UnionType` has no `__origin__` attribute.
- tests with runtime type and postponed annotations and guard the tests running with `compat.py310`.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #8479
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8479
Pull-request-sha: 12417654822272c5847e684c53677f665553ef0e

Change-Id: Ib3248043dd4a97324ac592c048385006536b2d49

lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_typed_mapping.py

index 6213cfef8459f91b21035252370062fede974bb9..7d71756780662eef61495372c0ef0c5ae61ce737 100644 (file)
@@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
+from ..util.typing import is_optional_union
 from ..util.typing import is_pep593
-from ..util.typing import NoneType
 from ..util.typing import Self
 from ..util.typing import typing_get_args
 
@@ -652,17 +652,15 @@ class MappedColumn(
     ) -> None:
         sqltype = self.column.type
 
-        nullable = False
+        if is_fwd_ref(argument):
+            argument = de_stringify_annotation(cls, argument)
 
-        if hasattr(argument, "__origin__"):
-            nullable = NoneType in argument.__args__  # type: ignore
+        nullable = is_optional_union(argument)
 
         if not self._has_nullable:
             self.column.nullable = nullable
 
         our_type = de_optionalize_union_types(argument)
-        if is_fwd_ref(our_type):
-            our_type = de_stringify_annotation(cls, our_type)
 
         use_args_from = None
         if is_pep593(our_type):
index 45fe63765b6ae1325ab0e145f8cf261f86352470..85c1bae72bc3f6d0dcfaeff8e1c19275af4d30fd 100644 (file)
@@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
 def expand_unions(
     type_: Type[Any], include_union: bool = False, discard_none: bool = False
 ) -> Tuple[Type[Any], ...]:
-    """Return a type as as a tuple of individual types, expanding for
+    """Return a type as a tuple of individual types, expanding for
     ``Union`` types."""
 
     if is_union(type_):
@@ -191,9 +191,14 @@ def is_optional(type_):
         type_,
         "Optional",
         "Union",
+        "UnionType",
     )
 
 
+def is_optional_union(type_: Any) -> bool:
+    return is_optional(type_) and NoneType in typing_get_args(type_)
+
+
 def is_union(type_):
     return is_origin_of(type_, "Union")
 
@@ -204,7 +209,7 @@ def is_origin_of(
     """return True if the given type has an __origin__ with the given name
     and optional module."""
 
-    origin = getattr(type_, "__origin__", None)
+    origin = typing_get_origin(type_)
     if origin is None:
         return False
 
index 98736cf025cd838ae5c7b6ddb159af92617c0b29..16cfee3407021b37cbe212ed1eba58668b7cb089 100644 (file)
@@ -52,6 +52,7 @@ from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
 
@@ -858,6 +859,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
             data: Mapped[Union[float, Decimal]] = mapped_column()
             reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
             optional_data: Mapped[
                 Optional[Union[float, Decimal]]
             ] = mapped_column()
@@ -872,9 +874,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             reverse_u_optional_data: Mapped[
                 Union[Decimal, float, None]
             ] = mapped_column()
+
             float_data: Mapped[float] = mapped_column()
             decimal_data: Mapped[Decimal] = mapped_column()
 
+            if compat.py310:
+                pep604_data: Mapped[float | Decimal] = mapped_column()
+                pep604_reverse: Mapped[Decimal | float] = mapped_column()
+                pep604_optional: Mapped[
+                    Decimal | float | None
+                ] = mapped_column()
+                pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+                pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+                pep604_optional_fwd: Mapped[
+                    "Decimal | float | None"
+                ] = mapped_column()
+
         is_(User.__table__.c.data.type, our_type)
         is_false(User.__table__.c.data.nullable)
         is_(User.__table__.c.reverse_data.type, our_type)
@@ -889,6 +904,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_(User.__table__.c.float_data.type, our_type)
         is_(User.__table__.c.decimal_data.type, our_type)
 
+        if compat.py310:
+            for suffix in ("", "_fwd"):
+                data_col = User.__table__.c[f"pep604_data{suffix}"]
+                reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+                optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+                is_(data_col.type, our_type)
+                is_false(data_col.nullable)
+                is_(reverse_col.type, our_type)
+                is_false(reverse_col.nullable)
+                is_(optional_col.type, our_type)
+                is_true(optional_col.nullable)
+
     def test_missing_mapped_lhs(self, decl_base):
         with expect_raises_message(
             ArgumentError,