]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix optionalized forms for dict[]
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Nov 2022 14:46:43 +0000 (09:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2022 14:49:02 +0000 (09:49 -0500)
Fixed a suite of issues involving :class:`.Mapped` use with dictionary
types, such as ``Mapped[dict[str, str] | None]``, would not be correctly
interpreted in Declarative ORM mappings. Support to correctly
"de-optionalize" this type including for lookup in type_annotation_map has
been fixed.

Fixes: #8777
Change-Id: Iaba90791dea341d00eaff788d40b0a4e48dab02e

doc/build/changelog/unreleased_20/8777.rst [new file with mode: 0644]
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8777.rst b/doc/build/changelog/unreleased_20/8777.rst
new file mode 100644 (file)
index 0000000..b212246
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8777
+
+    Fixed a suite of issues involving :class:`.Mapped` use with dictionary
+    types, such as ``Mapped[dict[str, str] | None]``, would not be correctly
+    interpreted in Declarative ORM mappings. Support to correctly
+    "de-optionalize" this type including for lookup in type_annotation_map has
+    been fixed.
index c1da267f4dca1ceb8f989daf02856961141b64e3..520d61c4e686b100b2678bb2aa65350b96d0b93f 100644 (file)
@@ -437,7 +437,7 @@ class ColumnProperty(
             try:
                 return ce.info  # type: ignore
             except AttributeError:
-                return self.prop.info
+                return self.prop.info  # type: ignore
 
         def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]:
             """The full sequence of columns referenced by this
@@ -686,10 +686,10 @@ class MappedColumn(
     ) -> None:
         sqltype = self.column.type
 
-        if is_fwd_ref(argument):
+        if is_fwd_ref(argument, check_generic=True):
             assert originating_module is not None
             argument = de_stringify_annotation(
-                cls, argument, originating_module
+                cls, argument, originating_module, include_generic=True
             )
 
         if is_union(argument):
index e5bdbaa4f3700c3073c7d400cc55f46c92c78581..06f2d6d1d02b07a085836642af5d0a8c9cce2764 100644 (file)
@@ -2097,13 +2097,30 @@ def _cleanup_mapped_str_annotation(
             stack.append(g2)
             break
 
-    # stack: ['Mapped', 'List', 'Address']
-    if not re.match(r"""^["'].*["']$""", stack[-1]):
+    # stacks we want to rewrite, that is, quote the last entry which
+    # we think is a relationship class name:
+    #
+    #   ['Mapped', 'List', 'Address']
+    #   ['Mapped', 'A']
+    #
+    # stacks we dont want to rewrite, which are generally MappedColumn
+    # use cases:
+    #
+    # ['Mapped', "'Optional[Dict[str, str]]'"]
+    # ['Mapped', 'dict[str, str] | None']
+
+    if (
+        # avoid already quoted symbols such as
+        # ['Mapped', "'Optional[Dict[str, str]]'"]
+        not re.match(r"""^["'].*["']$""", stack[-1])
+        # avoid further generics like Dict[] such as
+        # ['Mapped', 'dict[str, str] | None']
+        and not re.match(r".*\[.*\]", stack[-1])
+    ):
         stripchars = "\"' "
         stack[-1] = ", ".join(
             f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
         )
-        # stack: ['Mapped', 'List', '"Address"']
 
         annotation = "[".join(stack) + ("]" * (len(stack) - 1))
 
index 9eb761eff0629016da1fe7608194d5d4c870da9b..f87ee845b1ab37f6ae512c2e6495dc25f939da45 100644 (file)
@@ -70,12 +70,37 @@ NoneFwd = ForwardRef("None")
 typing_get_args = get_args
 typing_get_origin = get_origin
 
+
 # copied from TypeShed, required in order to implement
 # MutableMapping.update()
 
 _AnnotationScanType = Union[Type[Any], str, ForwardRef]
 
 
+class ArgsTypeProcotol(Protocol):
+    """protocol for types that have ``__args__``
+
+    there's no public interface for this AFAIK
+
+    """
+
+    __args__: Tuple[_AnnotationScanType, ...]
+
+
+class GenericProtocol(Protocol[_T]):
+    """protocol for generic types.
+
+    this since Python.typing _GenericAlias is private
+
+    """
+
+    __args__: Tuple[_AnnotationScanType, ...]
+    __origin__: Type[_T]
+
+    def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]:
+        ...
+
+
 class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
     def keys(self) -> Iterable[_KT]:
         ...
@@ -93,6 +118,7 @@ def de_stringify_annotation(
     annotation: _AnnotationScanType,
     originating_module: str,
     str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+    include_generic: bool = False,
 ) -> Type[Any]:
     """Resolve annotations that may be string based into real objects.
 
@@ -119,6 +145,20 @@ def de_stringify_annotation(
             annotation = str_cleanup_fn(annotation, originating_module)
 
         annotation = eval_expression(annotation, originating_module)
+
+    if include_generic and is_generic(annotation):
+        elements = tuple(
+            de_stringify_annotation(
+                cls,
+                elem,
+                originating_module,
+                str_cleanup_fn=str_cleanup_fn,
+                include_generic=include_generic,
+            )
+            for elem in annotation.__args__
+        )
+
+        return annotation.copy_with(elements)
     return annotation  # type: ignore
 
 
@@ -172,7 +212,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
 
 def de_stringify_union_elements(
     cls: Type[Any],
-    annotation: _AnnotationScanType,
+    annotation: ArgsTypeProcotol,
     originating_module: str,
     str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
 ) -> Type[Any]:
@@ -181,7 +221,7 @@ def de_stringify_union_elements(
             de_stringify_annotation(
                 cls, anno, originating_module, str_cleanup_fn
             )
-            for anno in annotation.__args__  # type: ignore
+            for anno in annotation.__args__
         ]
     )
 
@@ -190,8 +230,19 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
     return type_ is not None and typing_get_origin(type_) is Annotated
 
 
-def is_fwd_ref(type_: _AnnotationScanType) -> bool:
-    return isinstance(type_, ForwardRef)
+def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
+    return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
+
+
+def is_fwd_ref(
+    type_: _AnnotationScanType, check_generic: bool = False
+) -> bool:
+    if isinstance(type_, ForwardRef):
+        return True
+    elif check_generic and is_generic(type_):
+        return any(is_fwd_ref(arg, True) for arg in type_.__args__)
+    else:
+        return False
 
 
 @overload
@@ -218,11 +269,12 @@ def de_optionalize_union_types(
     to not include the ``NoneType``.
 
     """
+
     if is_fwd_ref(type_):
         return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_))
 
     elif is_optional(type_):
-        typ = set(type_.__args__)  # type: ignore
+        typ = set(type_.__args__)
 
         typ.discard(NoneType)
         typ.discard(NoneFwd)
@@ -287,14 +339,14 @@ def expand_unions(
             typ.discard(NoneType)
 
         if include_union:
-            return (type_,) + tuple(typ)
+            return (type_,) + tuple(typ)  # type: ignore
         else:
-            return tuple(typ)
+            return tuple(typ)  # type: ignore
     else:
         return (type_,)
 
 
-def is_optional(type_: Any) -> bool:
+def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
     return is_origin_of(
         type_,
         "Optional",
@@ -307,7 +359,7 @@ def is_optional_union(type_: Any) -> bool:
     return is_optional(type_) and NoneType in typing_get_args(type_)
 
 
-def is_union(type_: Any) -> bool:
+def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
     return is_origin_of(type_, "Union")
 
 
index d66b08f4e0d64715dc3a59b18ba301c4e71029dd..0f00c2fe469b0be2f88398f4033cd20b972547e0 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from decimal import Decimal
+from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Set
@@ -13,6 +14,7 @@ from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import Numeric
 from sqlalchemy import select
 from sqlalchemy import String
@@ -209,6 +211,62 @@ class MappedColumnTest(_MappedColumnTest):
                 is_(optional_col.type, our_type)
                 is_true(optional_col.nullable)
 
+    @testing.combinations(
+        ("not_optional",),
+        ("optional",),
+        ("optional_fwd_ref",),
+        ("union_none",),
+        ("pep604", testing.requires.python310),
+        ("pep604_fwd_ref", testing.requires.python310),
+        argnames="optional_on_json",
+    )
+    @testing.combinations(
+        "include_mc_type", "derive_from_anno", argnames="include_mc_type"
+    )
+    def test_optional_styles_nested_brackets(
+        self, optional_on_json, include_mc_type
+    ):
+        class Base(DeclarativeBase):
+            if testing.requires.python310.enabled:
+                type_annotation_map = {
+                    Dict[str, str]: JSON,
+                    dict[str, str]: JSON,
+                }
+            else:
+                type_annotation_map = {
+                    Dict[str, str]: JSON,
+                }
+
+        if include_mc_type == "include_mc_type":
+            mc = mapped_column(JSON)
+        else:
+            mc = mapped_column()
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            if optional_on_json == "not_optional":
+                json: Mapped[Dict[str, str]] = mapped_column()  # type: ignore
+            elif optional_on_json == "optional":
+                json: Mapped[Optional[Dict[str, str]]] = mc
+            elif optional_on_json == "optional_fwd_ref":
+                json: Mapped["Optional[Dict[str, str]]"] = mc
+            elif optional_on_json == "union_none":
+                json: Mapped[Union[Dict[str, str], None]] = mc
+            elif optional_on_json == "pep604":
+                json: Mapped[dict[str, str] | None] = mc
+            elif optional_on_json == "pep604_fwd_ref":
+                json: Mapped["dict[str, str] | None"] = mc
+
+        is_(A.__table__.c.json.type._type_affinity, JSON)
+        if optional_on_json == "not_optional":
+            is_false(A.__table__.c.json.nullable)
+        else:
+            is_true(A.__table__.c.json.nullable)
+
     def test_typ_not_in_cls_namespace(self, decl_base):
         """test #8742.
 
index e3f5e59f4d703179aad5844085958f461e2ab27e..527954e16e78f17b1539bb93c6fd6a859e2c8eb4 100644 (file)
@@ -928,6 +928,62 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 is_(optional_col.type, our_type)
                 is_true(optional_col.nullable)
 
+    @testing.combinations(
+        ("not_optional",),
+        ("optional",),
+        ("optional_fwd_ref",),
+        ("union_none",),
+        ("pep604", testing.requires.python310),
+        ("pep604_fwd_ref", testing.requires.python310),
+        argnames="optional_on_json",
+    )
+    @testing.combinations(
+        "include_mc_type", "derive_from_anno", argnames="include_mc_type"
+    )
+    def test_optional_styles_nested_brackets(
+        self, optional_on_json, include_mc_type
+    ):
+        class Base(DeclarativeBase):
+            if testing.requires.python310.enabled:
+                type_annotation_map = {
+                    Dict[str, str]: JSON,
+                    dict[str, str]: JSON,
+                }
+            else:
+                type_annotation_map = {
+                    Dict[str, str]: JSON,
+                }
+
+        if include_mc_type == "include_mc_type":
+            mc = mapped_column(JSON)
+        else:
+            mc = mapped_column()
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            if optional_on_json == "not_optional":
+                json: Mapped[Dict[str, str]] = mapped_column()  # type: ignore
+            elif optional_on_json == "optional":
+                json: Mapped[Optional[Dict[str, str]]] = mc
+            elif optional_on_json == "optional_fwd_ref":
+                json: Mapped["Optional[Dict[str, str]]"] = mc
+            elif optional_on_json == "union_none":
+                json: Mapped[Union[Dict[str, str], None]] = mc
+            elif optional_on_json == "pep604":
+                json: Mapped[dict[str, str] | None] = mc
+            elif optional_on_json == "pep604_fwd_ref":
+                json: Mapped["dict[str, str] | None"] = mc
+
+        is_(A.__table__.c.json.type._type_affinity, JSON)
+        if optional_on_json == "not_optional":
+            is_false(A.__table__.c.json.nullable)
+        else:
+            is_true(A.__table__.c.json.nullable)
+
     def test_missing_mapped_lhs(self, decl_base):
         with expect_annotation_syntax_error("User.name"):