From: Mike Bayer Date: Wed, 19 Oct 2022 01:01:05 +0000 (-0400) Subject: de-optionalize union types to support Optional for m2o X-Git-Tag: rel_2_0_0b2~7^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=257002227b811c85c7887236321d9965455889bc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git de-optionalize union types to support Optional for m2o Fixed bug in new ORM typed declarative mappings where we did not include the ability to use ``Optional[]`` in the type annotation for a many-to-one relationship, even though this is common. Fixes: #8668 Change-Id: Idaf0846e49cc12095394b99ad6fe678735cf9242 --- diff --git a/doc/build/changelog/unreleased_20/8668.rst b/doc/build/changelog/unreleased_20/8668.rst new file mode 100644 index 0000000000..3dab4663fd --- /dev/null +++ b/doc/build/changelog/unreleased_20/8668.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm + :tickets: 8668 + + Fixed bug in new ORM typed declarative mappings where we did not include + the ability to use ``Optional[]`` in the type annotation for a many-to-one + relationship, even though this is common. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 86a0f82c56..81c26d3722 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -86,6 +86,7 @@ from ..sql.util import ClauseAdapter from ..sql.util import join_condition from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +from ..util.typing import de_optionalize_union_types from ..util.typing import Literal if typing.TYPE_CHECKING: @@ -1742,18 +1743,21 @@ class RelationshipProperty( self.lazy = "dynamic" self.strategy_key = (("lazy", self.lazy),) - if hasattr(argument, "__origin__"): + argument = de_optionalize_union_types(argument) - collection_class = argument.__origin__ # type: ignore - if issubclass(collection_class, abc.Collection): + if hasattr(argument, "__origin__"): + arg_origin = argument.__origin__ # type: ignore + if isinstance(arg_origin, type) and issubclass( + arg_origin, abc.Collection + ): if self.collection_class is None: - self.collection_class = collection_class + self.collection_class = arg_origin elif not is_write_only and not is_dynamic: self.uselist = False if argument.__args__: # type: ignore - if issubclass( - argument.__origin__, typing.Mapping # type: ignore + if isinstance(arg_origin, type) and issubclass( + arg_origin, typing.Mapping # type: ignore ): type_arg = argument.__args__[-1] # type: ignore else: diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index ae1f9b35e1..bff9482ec5 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -321,7 +321,7 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): id: Mapped[intpk] = mapped_column(init=False) email_address: Mapped[str] user_id: Mapped[user_fk] = mapped_column(init=False) - user: Mapped["User"] = relationship( + user: Mapped[Optional["User"]] = relationship( back_populates="addresses", default=None ) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index b64694fc59..ae8e9d746b 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1236,7 +1236,8 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): select(A).join(A.bs), "SELECT a.id FROM a JOIN b ON a.id = b.a_id" ) - def test_basic_bidirectional(self, decl_base): + @testing.combinations(True, False, argnames="optional_on_m2o") + def test_basic_bidirectional(self, decl_base, optional_on_m2o): class A(decl_base): __tablename__ = "a" @@ -1251,9 +1252,14 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(Integer, primary_key=True) a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) - a: Mapped["A"] = relationship( - back_populates="bs", primaryjoin=a_id == A.id - ) + if optional_on_m2o: + a: Mapped[Optional["A"]] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) + else: + a: Mapped["A"] = relationship( + back_populates="bs", primaryjoin=a_id == A.id + ) a1 = A(data="data") b1 = B()