From: Peter Schutt Date: Thu, 1 Sep 2022 23:11:40 +0000 (-0400) Subject: Detection of PEP 604 union syntax. X-Git-Tag: rel_2_0_0b1~78^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c3cfee5b00a40790c18d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Detection of PEP 604 union syntax. ### Description Fixes #8478 Handle `UnionType` as arguments to `Mapped`, e.g., `Mapped[str | None]`: - adds `utils.typing.is_optional_union()` used to detect if a column should be nullable. - adds `"UnionType"` to `utils.typing.is_optional()` names. - uses `get_origin()` in `utils.typing.is_origin_of()` as `UnionType` has no `__origin__` attribute. - tests with runtime type and postponed annotations and guard the tests running with `compat.py310`. ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [x] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #8479 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8479 Pull-request-sha: 12417654822272c5847e684c53677f665553ef0e Change-Id: Ib3248043dd4a97324ac592c048385006536b2d49 --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6213cfef84..7d71756780 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union from ..util.typing import is_pep593 -from ..util.typing import NoneType from ..util.typing import Self from ..util.typing import typing_get_args @@ -652,17 +652,15 @@ class MappedColumn( ) -> None: sqltype = self.column.type - nullable = False + if is_fwd_ref(argument): + argument = de_stringify_annotation(cls, argument) - if hasattr(argument, "__origin__"): - nullable = NoneType in argument.__args__ # type: ignore + nullable = is_optional_union(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - if is_fwd_ref(our_type): - our_type = de_stringify_annotation(cls, our_type) use_args_from = None if is_pep593(our_type): diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 45fe63765b..85c1bae72b 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]: def expand_unions( type_: Type[Any], include_union: bool = False, discard_none: bool = False ) -> Tuple[Type[Any], ...]: - """Return a type as as a tuple of individual types, expanding for + """Return a type as a tuple of individual types, expanding for ``Union`` types.""" if is_union(type_): @@ -191,9 +191,14 @@ def is_optional(type_): type_, "Optional", "Union", + "UnionType", ) +def is_optional_union(type_: Any) -> bool: + return is_optional(type_) and NoneType in typing_get_args(type_) + + def is_union(type_): return is_origin_of(type_, "Union") @@ -204,7 +209,7 @@ def is_origin_of( """return True if the given type has an __origin__ with the given name and optional module.""" - origin = getattr(type_, "__origin__", None) + origin = typing_get_origin(type_) if origin is None: return False diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 98736cf025..16cfee3407 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -52,6 +52,7 @@ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.util import compat from sqlalchemy.util.typing import Annotated @@ -858,6 +859,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): data: Mapped[Union[float, Decimal]] = mapped_column() reverse_data: Mapped[Union[Decimal, float]] = mapped_column() + optional_data: Mapped[ Optional[Union[float, Decimal]] ] = mapped_column() @@ -872,9 +874,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): reverse_u_optional_data: Mapped[ Union[Decimal, float, None] ] = mapped_column() + float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() + if compat.py310: + pep604_data: Mapped[float | Decimal] = mapped_column() + pep604_reverse: Mapped[Decimal | float] = mapped_column() + pep604_optional: Mapped[ + Decimal | float | None + ] = mapped_column() + pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() + pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() + pep604_optional_fwd: Mapped[ + "Decimal | float | None" + ] = mapped_column() + is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) is_(User.__table__.c.reverse_data.type, our_type) @@ -889,6 +904,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(User.__table__.c.float_data.type, our_type) is_(User.__table__.c.decimal_data.type, our_type) + if compat.py310: + for suffix in ("", "_fwd"): + data_col = User.__table__.c[f"pep604_data{suffix}"] + reverse_col = User.__table__.c[f"pep604_reverse{suffix}"] + optional_col = User.__table__.c[f"pep604_optional{suffix}"] + is_(data_col.type, our_type) + is_false(data_col.nullable) + is_(reverse_col.type, our_type) + is_false(reverse_col.nullable) + is_(optional_col.type, our_type) + is_true(optional_col.nullable) + def test_missing_mapped_lhs(self, decl_base): with expect_raises_message( ArgumentError,