]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
de-optionalize union types to support Optional for m2o
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Oct 2022 01:01:05 +0000 (21:01 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Wed, 19 Oct 2022 12:01:21 +0000 (12:01 +0000)
Fixed bug in new ORM typed declarative mappings where we did not include
the ability to use ``Optional[]`` in the type annotation for a many-to-one
relationship, even though this is common.

Fixes: #8668
Change-Id: Idaf0846e49cc12095394b99ad6fe678735cf9242

doc/build/changelog/unreleased_20/8668.rst [new file with mode: 0644]
lib/sqlalchemy/orm/relationships.py
test/orm/declarative/test_dc_transforms.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8668.rst b/doc/build/changelog/unreleased_20/8668.rst
new file mode 100644 (file)
index 0000000..3dab466
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8668
+
+    Fixed bug in new ORM typed declarative mappings where we did not include
+    the ability to use ``Optional[]`` in the type annotation for a many-to-one
+    relationship, even though this is common.
index 86a0f82c5660007aa4883e07fecc9ece188c87ec..81c26d37224ce1399a38892182266ca6f5ac88d5 100644 (file)
@@ -86,6 +86,7 @@ from ..sql.util import ClauseAdapter
 from ..sql.util import join_condition
 from ..sql.util import selectables_overlap
 from ..sql.util import visit_binary_product
+from ..util.typing import de_optionalize_union_types
 from ..util.typing import Literal
 
 if typing.TYPE_CHECKING:
@@ -1742,18 +1743,21 @@ class RelationshipProperty(
             self.lazy = "dynamic"
             self.strategy_key = (("lazy", self.lazy),)
 
-        if hasattr(argument, "__origin__"):
+        argument = de_optionalize_union_types(argument)
 
-            collection_class = argument.__origin__  # type: ignore
-            if issubclass(collection_class, abc.Collection):
+        if hasattr(argument, "__origin__"):
+            arg_origin = argument.__origin__  # type: ignore
+            if isinstance(arg_origin, type) and issubclass(
+                arg_origin, abc.Collection
+            ):
                 if self.collection_class is None:
-                    self.collection_class = collection_class
+                    self.collection_class = arg_origin
             elif not is_write_only and not is_dynamic:
                 self.uselist = False
 
             if argument.__args__:  # type: ignore
-                if issubclass(
-                    argument.__origin__, typing.Mapping  # type: ignore
+                if isinstance(arg_origin, type) and issubclass(
+                    arg_origin, typing.Mapping  # type: ignore
                 ):
                     type_arg = argument.__args__[-1]  # type: ignore
                 else:
index ae1f9b35e11426a69fff199515750f433b2330d2..bff9482ec5d542f87389fd9604103276909f6f7c 100644 (file)
@@ -321,7 +321,7 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
             id: Mapped[intpk] = mapped_column(init=False)
             email_address: Mapped[str]
             user_id: Mapped[user_fk] = mapped_column(init=False)
-            user: Mapped["User"] = relationship(
+            user: Mapped[Optional["User"]] = relationship(
                 back_populates="addresses", default=None
             )
 
index b64694fc597e884d2419dc29722459d668bd8760..ae8e9d746b9a13a2354acb973d46c3fe56b14ec3 100644 (file)
@@ -1236,7 +1236,8 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             select(A).join(A.bs), "SELECT a.id FROM a JOIN b ON a.id = b.a_id"
         )
 
-    def test_basic_bidirectional(self, decl_base):
+    @testing.combinations(True, False, argnames="optional_on_m2o")
+    def test_basic_bidirectional(self, decl_base, optional_on_m2o):
         class A(decl_base):
             __tablename__ = "a"
 
@@ -1251,9 +1252,14 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             id: Mapped[int] = mapped_column(Integer, primary_key=True)
             a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
 
-            a: Mapped["A"] = relationship(
-                back_populates="bs", primaryjoin=a_id == A.id
-            )
+            if optional_on_m2o:
+                a: Mapped[Optional["A"]] = relationship(
+                    back_populates="bs", primaryjoin=a_id == A.id
+                )
+            else:
+                a: Mapped["A"] = relationship(
+                    back_populates="bs", primaryjoin=a_id == A.id
+                )
 
         a1 = A(data="data")
         b1 = B()