]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
derive optional for nullable from interior of pep-593 types
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 29 Jan 2023 15:10:30 +0000 (10:10 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Jan 2023 03:19:29 +0000 (22:19 -0500)
Improved the ruleset used to interpret :pep:`593` ``Annotated`` types when
used with Annotated Declarative mapping, the inner type will be checked for
"Optional" in all cases which will be added to the criteria by which the
column is set as "nullable" or not; if the type within the ``Annotated``
container is optional (or unioned with ``None``), the column will be
considered nullable if there are no explicit
:paramref:`_orm.mapped_column.nullable` parameters overriding it.

Fixes: #9177
Change-Id: I4b1240da198e35b93006fd90f6cb259c9d2cbf30

doc/build/changelog/unreleased_20/9177.rst [new file with mode: 0644]
lib/sqlalchemy/orm/properties.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/9177.rst b/doc/build/changelog/unreleased_20/9177.rst
new file mode 100644 (file)
index 0000000..86dc1ba
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9177
+
+    Improved the ruleset used to interpret :pep:`593` ``Annotated`` types when
+    used with Annotated Declarative mapping, the inner type will be checked for
+    "Optional" in all cases which will be added to the criteria by which the
+    column is set as "nullable" or not; if the type within the ``Annotated``
+    container is optional (or unioned with ``None``), the column will be
+    considered nullable if there are no explicit
+    :paramref:`_orm.mapped_column.nullable` parameters overriding it.
index 9feb72e40f57586e685312dce93317421f0a06fe..c67c22942399b822a5db5499265ad492bed961a1 100644 (file)
@@ -730,14 +730,23 @@ class MappedColumn(
         our_type = de_optionalize_union_types(argument)
 
         use_args_from = None
+
         if is_pep593(our_type):
             our_type_is_pep593 = True
-            for elem in typing_get_args(our_type):
+            pep_593_components = typing_get_args(our_type)
+            raw_pep_593_type = pep_593_components[0]
+            if is_optional_union(raw_pep_593_type):
+                nullable = True
+                if not self._has_nullable:
+                    self.column.nullable = nullable
+                raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type)
+            for elem in pep_593_components[1:]:
                 if isinstance(elem, MappedColumn):
                     use_args_from = elem
                     break
         else:
             our_type_is_pep593 = False
+            raw_pep_593_type = None
 
         if use_args_from is not None:
             if (
@@ -752,9 +761,9 @@ class MappedColumn(
             new_sqltype = None
 
             if our_type_is_pep593:
-                checks = (our_type,) + typing_get_args(our_type)
+                checks = [our_type, raw_pep_593_type]
             else:
-                checks = (our_type,)
+                checks = [our_type]
 
             for check_type in checks:
 
index ae10f7d8e4200395efe7f26a4b5bc62cde7fbac4..a83b02cd028a7bf4639a54738bee74fad2238ab0 100644 (file)
@@ -457,6 +457,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             args = ()
 
+        global anno_str, anno_str_optional, anno_str_mc
+        global anno_str_optional_mc, anno_str_mc_nullable
+        global anno_str_optional_mc_notnull
+        anno_str = Annotated[str, 50]
+        anno_str_optional = Annotated[Optional[str], 30]
+
+        anno_str_mc = Annotated[str, mapped_column()]
+        anno_str_optional_mc = Annotated[Optional[str], mapped_column()]
+        anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)]
+        anno_str_optional_mc_notnull = Annotated[
+            Optional[str], mapped_column(nullable=False)
+        ]
+
+        decl_base.registry.update_type_annotation_map(
+            {anno_str: String(50), anno_str_optional: String(30)}
+        )
+
         class User(decl_base):
             __tablename__ = "users"
 
@@ -473,6 +490,36 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 *args, nullable=True
             )
 
+            # test #9177 cases
+            anno_1a: Mapped[anno_str] = mapped_column(*args)
+            anno_1b: Mapped[anno_str] = mapped_column(*args, nullable=True)
+
+            anno_2a: Mapped[anno_str_optional] = mapped_column(*args)
+            anno_2b: Mapped[anno_str_optional] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_3a: Mapped[anno_str_mc] = mapped_column(*args)
+            anno_3b: Mapped[anno_str_mc] = mapped_column(*args, nullable=True)
+            anno_3c: Mapped[Optional[anno_str_mc]] = mapped_column(*args)
+
+            anno_4a: Mapped[anno_str_optional_mc] = mapped_column(*args)
+            anno_4b: Mapped[anno_str_optional_mc] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_5a: Mapped[anno_str_mc_nullable] = mapped_column(*args)
+            anno_5b: Mapped[anno_str_mc_nullable] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_6a: Mapped[anno_str_optional_mc_notnull] = mapped_column(
+                *args
+            )
+            anno_6b: Mapped[anno_str_optional_mc_notnull] = mapped_column(
+                *args, nullable=True
+            )
+
         is_false(User.__table__.c.lnnl_rndf.nullable)
         is_false(User.__table__.c.lnnl_rnnl.nullable)
         is_true(User.__table__.c.lnnl_rnl.nullable)
@@ -481,6 +528,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_false(User.__table__.c.lnl_rnnl.nullable)
         is_true(User.__table__.c.lnl_rnl.nullable)
 
+        is_false(User.__table__.c.anno_1a.nullable)
+        is_true(User.__table__.c.anno_1b.nullable)
+        is_true(User.__table__.c.anno_2a.nullable)
+        is_false(User.__table__.c.anno_2b.nullable)
+        is_false(User.__table__.c.anno_3a.nullable)
+        is_true(User.__table__.c.anno_3b.nullable)
+        is_true(User.__table__.c.anno_3c.nullable)
+        is_true(User.__table__.c.anno_4a.nullable)
+        is_false(User.__table__.c.anno_4b.nullable)
+        is_true(User.__table__.c.anno_5a.nullable)
+        is_false(User.__table__.c.anno_5b.nullable)
+        is_false(User.__table__.c.anno_6a.nullable)
+        is_true(User.__table__.c.anno_6b.nullable)
+
         # test #8410
         is_false(User.__table__.c.lnnl_rndf._copy().nullable)
         is_false(User.__table__.c.lnnl_rnnl._copy().nullable)
index 8838afd0ff3c7e8e2b997dc99fe7b1096956bd3b..9a2faf22a46e1e2f3d44399d3abc921df02e2f11 100644 (file)
@@ -448,6 +448,23 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             args = ()
 
+        # anno only: global anno_str, anno_str_optional, anno_str_mc
+        # anno only: global anno_str_optional_mc, anno_str_mc_nullable
+        # anno only: global anno_str_optional_mc_notnull
+        anno_str = Annotated[str, 50]
+        anno_str_optional = Annotated[Optional[str], 30]
+
+        anno_str_mc = Annotated[str, mapped_column()]
+        anno_str_optional_mc = Annotated[Optional[str], mapped_column()]
+        anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)]
+        anno_str_optional_mc_notnull = Annotated[
+            Optional[str], mapped_column(nullable=False)
+        ]
+
+        decl_base.registry.update_type_annotation_map(
+            {anno_str: String(50), anno_str_optional: String(30)}
+        )
+
         class User(decl_base):
             __tablename__ = "users"
 
@@ -464,6 +481,36 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 *args, nullable=True
             )
 
+            # test #9177 cases
+            anno_1a: Mapped[anno_str] = mapped_column(*args)
+            anno_1b: Mapped[anno_str] = mapped_column(*args, nullable=True)
+
+            anno_2a: Mapped[anno_str_optional] = mapped_column(*args)
+            anno_2b: Mapped[anno_str_optional] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_3a: Mapped[anno_str_mc] = mapped_column(*args)
+            anno_3b: Mapped[anno_str_mc] = mapped_column(*args, nullable=True)
+            anno_3c: Mapped[Optional[anno_str_mc]] = mapped_column(*args)
+
+            anno_4a: Mapped[anno_str_optional_mc] = mapped_column(*args)
+            anno_4b: Mapped[anno_str_optional_mc] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_5a: Mapped[anno_str_mc_nullable] = mapped_column(*args)
+            anno_5b: Mapped[anno_str_mc_nullable] = mapped_column(
+                *args, nullable=False
+            )
+
+            anno_6a: Mapped[anno_str_optional_mc_notnull] = mapped_column(
+                *args
+            )
+            anno_6b: Mapped[anno_str_optional_mc_notnull] = mapped_column(
+                *args, nullable=True
+            )
+
         is_false(User.__table__.c.lnnl_rndf.nullable)
         is_false(User.__table__.c.lnnl_rnnl.nullable)
         is_true(User.__table__.c.lnnl_rnl.nullable)
@@ -472,6 +519,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_false(User.__table__.c.lnl_rnnl.nullable)
         is_true(User.__table__.c.lnl_rnl.nullable)
 
+        is_false(User.__table__.c.anno_1a.nullable)
+        is_true(User.__table__.c.anno_1b.nullable)
+        is_true(User.__table__.c.anno_2a.nullable)
+        is_false(User.__table__.c.anno_2b.nullable)
+        is_false(User.__table__.c.anno_3a.nullable)
+        is_true(User.__table__.c.anno_3b.nullable)
+        is_true(User.__table__.c.anno_3c.nullable)
+        is_true(User.__table__.c.anno_4a.nullable)
+        is_false(User.__table__.c.anno_4b.nullable)
+        is_true(User.__table__.c.anno_5a.nullable)
+        is_false(User.__table__.c.anno_5b.nullable)
+        is_false(User.__table__.c.anno_6a.nullable)
+        is_true(User.__table__.c.anno_6b.nullable)
+
         # test #8410
         is_false(User.__table__.c.lnnl_rndf._copy().nullable)
         is_false(User.__table__.c.lnnl_rnnl._copy().nullable)