From: Mike Bayer Date: Mon, 21 Nov 2022 14:46:43 +0000 (-0500) Subject: fix optionalized forms for dict[] X-Git-Tag: rel_2_0_0b4~43^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=509ffeedefca1ad0ad8e29c6c3410d270fb3d2b9;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix optionalized forms for dict[] Fixed a suite of issues involving :class:`.Mapped` use with dictionary types, such as ``Mapped[dict[str, str] | None]``, would not be correctly interpreted in Declarative ORM mappings. Support to correctly "de-optionalize" this type including for lookup in type_annotation_map has been fixed. Fixes: #8777 Change-Id: Iaba90791dea341d00eaff788d40b0a4e48dab02e --- diff --git a/doc/build/changelog/unreleased_20/8777.rst b/doc/build/changelog/unreleased_20/8777.rst new file mode 100644 index 0000000000..b212246052 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8777.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 8777 + + Fixed a suite of issues involving :class:`.Mapped` use with dictionary + types, such as ``Mapped[dict[str, str] | None]``, would not be correctly + interpreted in Declarative ORM mappings. Support to correctly + "de-optionalize" this type including for lookup in type_annotation_map has + been fixed. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c1da267f4d..520d61c4e6 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -437,7 +437,7 @@ class ColumnProperty( try: return ce.info # type: ignore except AttributeError: - return self.prop.info + return self.prop.info # type: ignore def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]: """The full sequence of columns referenced by this @@ -686,10 +686,10 @@ class MappedColumn( ) -> None: sqltype = self.column.type - if is_fwd_ref(argument): + if is_fwd_ref(argument, check_generic=True): assert originating_module is not None argument = de_stringify_annotation( - cls, argument, originating_module + cls, argument, originating_module, include_generic=True ) if is_union(argument): diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index e5bdbaa4f3..06f2d6d1d0 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -2097,13 +2097,30 @@ def _cleanup_mapped_str_annotation( stack.append(g2) break - # stack: ['Mapped', 'List', 'Address'] - if not re.match(r"""^["'].*["']$""", stack[-1]): + # stacks we want to rewrite, that is, quote the last entry which + # we think is a relationship class name: + # + # ['Mapped', 'List', 'Address'] + # ['Mapped', 'A'] + # + # stacks we dont want to rewrite, which are generally MappedColumn + # use cases: + # + # ['Mapped', "'Optional[Dict[str, str]]'"] + # ['Mapped', 'dict[str, str] | None'] + + if ( + # avoid already quoted symbols such as + # ['Mapped', "'Optional[Dict[str, str]]'"] + not re.match(r"""^["'].*["']$""", stack[-1]) + # avoid further generics like Dict[] such as + # ['Mapped', 'dict[str, str] | None'] + and not re.match(r".*\[.*\]", stack[-1]) + ): stripchars = "\"' " stack[-1] = ", ".join( f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",") ) - # stack: ['Mapped', 'List', '"Address"'] annotation = "[".join(stack) + ("]" * (len(stack) - 1)) diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 9eb761eff0..f87ee845b1 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -70,12 +70,37 @@ NoneFwd = ForwardRef("None") typing_get_args = get_args typing_get_origin = get_origin + # copied from TypeShed, required in order to implement # MutableMapping.update() _AnnotationScanType = Union[Type[Any], str, ForwardRef] +class ArgsTypeProcotol(Protocol): + """protocol for types that have ``__args__`` + + there's no public interface for this AFAIK + + """ + + __args__: Tuple[_AnnotationScanType, ...] + + +class GenericProtocol(Protocol[_T]): + """protocol for generic types. + + this since Python.typing _GenericAlias is private + + """ + + __args__: Tuple[_AnnotationScanType, ...] + __origin__: Type[_T] + + def copy_with(self, params: Tuple[_AnnotationScanType, ...]) -> Type[_T]: + ... + + class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: ... @@ -93,6 +118,7 @@ def de_stringify_annotation( annotation: _AnnotationScanType, originating_module: str, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, ) -> Type[Any]: """Resolve annotations that may be string based into real objects. @@ -119,6 +145,20 @@ def de_stringify_annotation( annotation = str_cleanup_fn(annotation, originating_module) annotation = eval_expression(annotation, originating_module) + + if include_generic and is_generic(annotation): + elements = tuple( + de_stringify_annotation( + cls, + elem, + originating_module, + str_cleanup_fn=str_cleanup_fn, + include_generic=include_generic, + ) + for elem in annotation.__args__ + ) + + return annotation.copy_with(elements) return annotation # type: ignore @@ -172,7 +212,7 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: def de_stringify_union_elements( cls: Type[Any], - annotation: _AnnotationScanType, + annotation: ArgsTypeProcotol, originating_module: str, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, ) -> Type[Any]: @@ -181,7 +221,7 @@ def de_stringify_union_elements( de_stringify_annotation( cls, anno, originating_module, str_cleanup_fn ) - for anno in annotation.__args__ # type: ignore + for anno in annotation.__args__ ] ) @@ -190,8 +230,19 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated -def is_fwd_ref(type_: _AnnotationScanType) -> bool: - return isinstance(type_, ForwardRef) +def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: + return hasattr(type_, "__args__") and hasattr(type_, "__origin__") + + +def is_fwd_ref( + type_: _AnnotationScanType, check_generic: bool = False +) -> bool: + if isinstance(type_, ForwardRef): + return True + elif check_generic and is_generic(type_): + return any(is_fwd_ref(arg, True) for arg in type_.__args__) + else: + return False @overload @@ -218,11 +269,12 @@ def de_optionalize_union_types( to not include the ``NoneType``. """ + if is_fwd_ref(type_): return de_optionalize_fwd_ref_union_types(cast(ForwardRef, type_)) elif is_optional(type_): - typ = set(type_.__args__) # type: ignore + typ = set(type_.__args__) typ.discard(NoneType) typ.discard(NoneFwd) @@ -287,14 +339,14 @@ def expand_unions( typ.discard(NoneType) if include_union: - return (type_,) + tuple(typ) + return (type_,) + tuple(typ) # type: ignore else: - return tuple(typ) + return tuple(typ) # type: ignore else: return (type_,) -def is_optional(type_: Any) -> bool: +def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of( type_, "Optional", @@ -307,7 +359,7 @@ def is_optional_union(type_: Any) -> bool: return is_optional(type_) and NoneType in typing_get_args(type_) -def is_union(type_: Any) -> bool: +def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]: return is_origin_of(type_, "Union") diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index d66b08f4e0..0f00c2fe46 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -1,6 +1,7 @@ from __future__ import annotations from decimal import Decimal +from typing import Dict from typing import List from typing import Optional from typing import Set @@ -13,6 +14,7 @@ from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import Numeric from sqlalchemy import select from sqlalchemy import String @@ -209,6 +211,62 @@ class MappedColumnTest(_MappedColumnTest): is_(optional_col.type, our_type) is_true(optional_col.nullable) + @testing.combinations( + ("not_optional",), + ("optional",), + ("optional_fwd_ref",), + ("union_none",), + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + argnames="optional_on_json", + ) + @testing.combinations( + "include_mc_type", "derive_from_anno", argnames="include_mc_type" + ) + def test_optional_styles_nested_brackets( + self, optional_on_json, include_mc_type + ): + class Base(DeclarativeBase): + if testing.requires.python310.enabled: + type_annotation_map = { + Dict[str, str]: JSON, + dict[str, str]: JSON, + } + else: + type_annotation_map = { + Dict[str, str]: JSON, + } + + if include_mc_type == "include_mc_type": + mc = mapped_column(JSON) + else: + mc = mapped_column() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + if optional_on_json == "not_optional": + json: Mapped[Dict[str, str]] = mapped_column() # type: ignore + elif optional_on_json == "optional": + json: Mapped[Optional[Dict[str, str]]] = mc + elif optional_on_json == "optional_fwd_ref": + json: Mapped["Optional[Dict[str, str]]"] = mc + elif optional_on_json == "union_none": + json: Mapped[Union[Dict[str, str], None]] = mc + elif optional_on_json == "pep604": + json: Mapped[dict[str, str] | None] = mc + elif optional_on_json == "pep604_fwd_ref": + json: Mapped["dict[str, str] | None"] = mc + + is_(A.__table__.c.json.type._type_affinity, JSON) + if optional_on_json == "not_optional": + is_false(A.__table__.c.json.nullable) + else: + is_true(A.__table__.c.json.nullable) + def test_typ_not_in_cls_namespace(self, decl_base): """test #8742. diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index e3f5e59f4d..527954e16e 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -928,6 +928,62 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_(optional_col.type, our_type) is_true(optional_col.nullable) + @testing.combinations( + ("not_optional",), + ("optional",), + ("optional_fwd_ref",), + ("union_none",), + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + argnames="optional_on_json", + ) + @testing.combinations( + "include_mc_type", "derive_from_anno", argnames="include_mc_type" + ) + def test_optional_styles_nested_brackets( + self, optional_on_json, include_mc_type + ): + class Base(DeclarativeBase): + if testing.requires.python310.enabled: + type_annotation_map = { + Dict[str, str]: JSON, + dict[str, str]: JSON, + } + else: + type_annotation_map = { + Dict[str, str]: JSON, + } + + if include_mc_type == "include_mc_type": + mc = mapped_column(JSON) + else: + mc = mapped_column() + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column() + + if optional_on_json == "not_optional": + json: Mapped[Dict[str, str]] = mapped_column() # type: ignore + elif optional_on_json == "optional": + json: Mapped[Optional[Dict[str, str]]] = mc + elif optional_on_json == "optional_fwd_ref": + json: Mapped["Optional[Dict[str, str]]"] = mc + elif optional_on_json == "union_none": + json: Mapped[Union[Dict[str, str], None]] = mc + elif optional_on_json == "pep604": + json: Mapped[dict[str, str] | None] = mc + elif optional_on_json == "pep604_fwd_ref": + json: Mapped["dict[str, str] | None"] = mc + + is_(A.__table__.c.json.type._type_affinity, JSON) + if optional_on_json == "not_optional": + is_false(A.__table__.c.json.nullable) + else: + is_true(A.__table__.c.json.nullable) + def test_missing_mapped_lhs(self, decl_base): with expect_annotation_syntax_error("User.name"):