From: Mike Bayer Date: Mon, 30 Dec 2024 18:17:29 +0000 (-0500) Subject: further fixes for _cleanup_mapped_str_annotation X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=fd3d17a30b15cc45ba18efaeb24ecc29b0ea1087;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git further fixes for _cleanup_mapped_str_annotation Fixed issues in type handling within the ``type_annotation_map`` feature which prevented the use of unions, using either pep-604 or ``Union`` syntaxes under future annotations mode, which contained multiple generic types as elements from being correctly resolvable. also adds some further tests to assert that None added into the type map for pep695, typing.NewType etc. sets up nullability on the column Fixes: #12207 Change-Id: I4057694cf35868972db2942721049d79301b19c4 --- diff --git a/doc/build/changelog/unreleased_20/12207.rst b/doc/build/changelog/unreleased_20/12207.rst new file mode 100644 index 0000000000..a6457b90ba --- /dev/null +++ b/doc/build/changelog/unreleased_20/12207.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 12207 + + Fixed issues in type handling within the ``type_annotation_map`` feature + which prevented the use of unions, using either pep-604 or ``Union`` + syntaxes under future annotations mode, which contained multiple generic + types as elements from being correctly resolvable. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 9c9bd249fa..4c7850971a 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -1308,10 +1308,8 @@ class _ClassScanMapperConfig(_MapperConfig): type(attr_value), required=False, is_dataclass_field=is_dataclass_field, - expect_mapped=expect_mapped - and not is_dataclass, # self.allow_dataclass_fields, + expect_mapped=expect_mapped and not is_dataclass, ) - if extracted is None: # ClassVar can come out here return None @@ -1320,8 +1318,8 @@ class _ClassScanMapperConfig(_MapperConfig): if attr_value is None and not is_literal(extracted_mapped_annotation): for elem in get_args(extracted_mapped_annotation): - if isinstance(elem, str) or is_fwd_ref( - elem, check_generic=True + if is_fwd_ref( + elem, check_generic=True, check_for_plain_string=True ): elem = de_stringify_annotation( self.cls, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c6fe71dbb0..2b15e7f2a1 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -43,7 +43,6 @@ from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import RelationshipProperty from .util import de_stringify_annotation -from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. import log @@ -60,7 +59,6 @@ from ..util.typing import includes_none from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 from ..util.typing import is_pep695 -from ..util.typing import is_union from ..util.typing import Self if TYPE_CHECKING: @@ -738,20 +736,14 @@ class MappedColumn( ) -> None: sqltype = self.column.type - if isinstance(argument, str) or is_fwd_ref( - argument, check_generic=True + if is_fwd_ref( + argument, check_generic=True, check_for_plain_string=True ): assert originating_module is not None argument = de_stringify_annotation( cls, argument, originating_module, include_generic=True ) - if is_union(argument): - assert originating_module is not None - argument = de_stringify_union_elements( - cls, argument, originating_module - ) - nullable = includes_none(argument) if not self._has_nullable: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ccabeb4cfd..4dc26dfd80 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -87,9 +87,6 @@ from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation as _de_stringify_annotation -from ..util.typing import ( - de_stringify_union_elements as _de_stringify_union_elements, -) from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import fixup_container_fwd_refs from ..util.typing import get_origin @@ -125,7 +122,6 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Selectable from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType - from ..util.typing import ArgsTypeProtocol _T = TypeVar("_T", bound=Any) @@ -142,7 +138,6 @@ all_cascades = frozenset( ) ) - _de_stringify_partial = functools.partial( functools.partial, locals_=util.immutabledict( @@ -175,23 +170,6 @@ de_stringify_annotation = cast( ) -class _DeStringifyUnionElements(Protocol): - def __call__( - self, - cls: Type[Any], - annotation: ArgsTypeProtocol, - originating_module: str, - *, - str_cleanup_fn: Optional[Callable[[str, str], str]] = None, - ) -> Type[Any]: ... - - -de_stringify_union_elements = cast( - _DeStringifyUnionElements, - _de_stringify_partial(_de_stringify_union_elements), -) - - class _EvalNameOnly(Protocol): def __call__(self, name: str, module_name: str) -> Any: ... @@ -2231,7 +2209,7 @@ def _cleanup_mapped_str_annotation( inner: Optional[Match[str]] - mm = re.match(r"^(.+?)\[(.+)\]$", annotation) + mm = re.match(r"^([^ \|]+?)\[(.+)\]$", annotation) if not mm: return annotation @@ -2271,7 +2249,7 @@ def _cleanup_mapped_str_annotation( while True: stack.append(real_symbol if mm is inner else inner.group(1)) g2 = inner.group(2) - inner = re.match(r"^(.+?)\[(.+)\]$", g2) + inner = re.match(r"^([^ \|]+?)\[(.+)\]$", g2) if inner is None: stack.append(g2) break @@ -2293,8 +2271,10 @@ def _cleanup_mapped_str_annotation( # ['Mapped', "'Optional[Dict[str, str]]'"] not re.match(r"""^["'].*["']$""", stack[-1]) # avoid further generics like Dict[] such as - # ['Mapped', 'dict[str, str] | None'] - and not re.match(r".*\[.*\]", stack[-1]) + # ['Mapped', 'dict[str, str] | None'], + # ['Mapped', 'list[int] | list[str]'], + # ['Mapped', 'Union[list[int], list[str]]'], + and not re.search(r"[\[\]]", stack[-1]) ): stripchars = "\"' " stack[-1] = ", ".join( @@ -2334,6 +2314,11 @@ def _extract_mapped_subtype( return None try: + # destringify the "outside" of the annotation. note we are not + # adding include_generic so it will *not* dig into generic contents, + # which will remain as ForwardRef or plain str under future annotations + # mode. The full destringify happens later when mapped_column goes + # to do a full lookup in the registry type_annotations_map. annotated = de_stringify_annotation( cls, raw_annotation, diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 8565d4d453..9573c52ee6 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -329,28 +329,6 @@ def resolve_name_to_real_class_name(name: str, module_name: str) -> str: return getattr(obj, "__name__", name) -def de_stringify_union_elements( - cls: Type[Any], - annotation: ArgsTypeProtocol, - originating_module: str, - locals_: Mapping[str, Any], - *, - str_cleanup_fn: Optional[Callable[[str, str], str]] = None, -) -> Type[Any]: - return make_union_type( - *[ - de_stringify_annotation( - cls, - anno, - originating_module, - {}, - str_cleanup_fn=str_cleanup_fn, - ) - for anno in annotation.__args__ - ] - ) - - def is_pep593(type_: Optional[Any]) -> bool: return type_ is not None and get_origin(type_) is Annotated @@ -425,12 +403,21 @@ def pep695_values(type_: _AnnotationScanType) -> Set[Any]: def is_fwd_ref( - type_: _AnnotationScanType, check_generic: bool = False + type_: _AnnotationScanType, + check_generic: bool = False, + check_for_plain_string: bool = False, ) -> TypeGuard[ForwardRef]: - if isinstance(type_, ForwardRef): + if check_for_plain_string and isinstance(type_, str): + return True + elif isinstance(type_, ForwardRef): return True elif check_generic and is_generic(type_): - return any(is_fwd_ref(arg, True) for arg in type_.__args__) + return any( + is_fwd_ref( + arg, True, check_for_plain_string=check_for_plain_string + ) + for arg in type_.__args__ + ) else: return False diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index 165f43b42d..9b0d4f334b 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -30,9 +30,11 @@ from sqlalchemy.orm import KeyFuncDict from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.orm.util import _cleanup_mapped_str_annotation from sqlalchemy.sql import sqltypes from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from .test_typed_mapping import expect_annotation_syntax_error @@ -49,6 +51,89 @@ class M3: pass +class AnnoUtilTest(fixtures.TestBase): + @testing.combinations( + ("Mapped[Address]", 'Mapped["Address"]'), + ('Mapped["Address"]', 'Mapped["Address"]'), + ("Mapped['Address']", "Mapped['Address']"), + ("Mapped[Address | None]", 'Mapped["Address | None"]'), + ("Mapped[None | Address]", 'Mapped["None | Address"]'), + ('Mapped["Address | None"]', 'Mapped["Address | None"]'), + ("Mapped['None | Address']", "Mapped['None | Address']"), + ('Mapped["Address" | "None"]', 'Mapped["Address" | "None"]'), + ('Mapped["None" | "Address"]', 'Mapped["None" | "Address"]'), + ("Mapped[A_]", 'Mapped["A_"]'), + ("Mapped[_TypingLiteral]", 'Mapped["_TypingLiteral"]'), + ("Mapped[datetime.datetime]", 'Mapped["datetime.datetime"]'), + ("Mapped[List[Edge]]", 'Mapped[List["Edge"]]'), + ( + "Mapped[collections.abc.MutableSequence[B]]", + 'Mapped[collections.abc.MutableSequence["B"]]', + ), + ("Mapped[typing.Sequence[B]]", 'Mapped[typing.Sequence["B"]]'), + ("Mapped[dict[str, str]]", 'Mapped[dict["str", "str"]]'), + ("Mapped[Dict[str, str]]", 'Mapped[Dict["str", "str"]]'), + ("Mapped[list[str]]", 'Mapped[list["str"]]'), + ("Mapped[dict[str, str] | None]", "Mapped[dict[str, str] | None]"), + ("Mapped[Optional[anno_str_mc]]", 'Mapped[Optional["anno_str_mc"]]'), + ( + "Mapped[Optional[Dict[str, str]]]", + 'Mapped[Optional[Dict["str", "str"]]]', + ), + ( + "Mapped[Optional[Union[Decimal, float]]]", + 'Mapped[Optional[Union["Decimal", "float"]]]', + ), + ( + "Mapped[Optional[Union[list[int], list[str]]]]", + "Mapped[Optional[Union[list[int], list[str]]]]", + ), + ("Mapped[TestType[str]]", 'Mapped[TestType["str"]]'), + ("Mapped[TestType[str, str]]", 'Mapped[TestType["str", "str"]]'), + ("Mapped[Union[A, None]]", 'Mapped[Union["A", "None"]]'), + ("Mapped[Union[Decimal, float]]", 'Mapped[Union["Decimal", "float"]]'), + ( + "Mapped[Union[Decimal, float, None]]", + 'Mapped[Union["Decimal", "float", "None"]]', + ), + ( + "Mapped[Union[Dict[str, str], None]]", + "Mapped[Union[Dict[str, str], None]]", + ), + ("Mapped[Union[float, Decimal]]", 'Mapped[Union["float", "Decimal"]]'), + ( + "Mapped[Union[list[int], list[str]]]", + "Mapped[Union[list[int], list[str]]]", + ), + ( + "Mapped[Union[list[int], list[str], None]]", + "Mapped[Union[list[int], list[str], None]]", + ), + ( + "Mapped[Union[None, Dict[str, str]]]", + "Mapped[Union[None, Dict[str, str]]]", + ), + ( + "Mapped[Union[None, list[int], list[str]]]", + "Mapped[Union[None, list[int], list[str]]]", + ), + ("Mapped[A | None]", 'Mapped["A | None"]'), + ("Mapped[Decimal | float]", 'Mapped["Decimal | float"]'), + ("Mapped[Decimal | float | None]", 'Mapped["Decimal | float | None"]'), + ( + "Mapped[list[int] | list[str] | None]", + "Mapped[list[int] | list[str] | None]", + ), + ("Mapped[None | dict[str, str]]", "Mapped[None | dict[str, str]]"), + ( + "Mapped[None | list[int] | list[str]]", + "Mapped[None | list[int] | list[str]]", + ), + ) + def test_cleanup_mapped_str_annotation(self, given, expected): + eq_(_cleanup_mapped_str_annotation(given, __name__), expected) + + class MappedColumnTest(_MappedColumnTest): def test_fully_qualified_mapped_name(self, decl_base): """test #8853, regression caused by #8759 ;) diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index e6cbf1d1fe..a9cd459443 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -116,8 +116,9 @@ _UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2] _StrTypeAlias: TypeAlias = str -_StrPep695: TypeAlias = str -_UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] +if TYPE_CHECKING: + _StrPep695: TypeAlias = str + _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] @@ -157,6 +158,17 @@ type _JsonPep695 = _JsonPep604 ) +def make_pep695_type(name, definition): + lcls = {} + exec( + f""" +type {name} = {definition} +""", + lcls, + ) + return lcls[name] + + def expect_annotation_syntax_error(name): return expect_raises_message( sa_exc.ArgumentError, @@ -862,6 +874,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): "optional", "optional_union", "optional_union_604", + "union_newtype", + "union_null_newtype", + "union_695", + "union_null_695", ], ) @testing.variation("in_map", ["yes", "no", "value"]) @@ -886,12 +902,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): tat = TypeAliasType("tat", Optional[Union[str, int]]) elif option.optional_union_604: tat = TypeAliasType("tat", Optional[str | int]) + elif option.union_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int]) + elif option.union_null_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int, None]) + elif option.union_695: + tat = make_pep695_type("tat", str | int) + elif option.union_null_695: + tat = make_pep695_type("tat", str | int | None) else: option.fail() if in_map.yes: decl_base.registry.update_type_annotation_map({tat: String(99)}) - elif in_map.value: + elif in_map.value and "newtype" not in option.name: decl_base.registry.update_type_annotation_map( {tat.__value__: String(99)} ) @@ -907,7 +933,12 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if in_map.yes: col = declare() length = 99 - elif in_map.value or option.optional or option.plain: + elif ( + in_map.value + and "newtype" not in option.name + or option.optional + or option.plain + ): with expect_deprecated( "Matching the provided TypeAliasType 'tat' on its " "resolved value without matching it in the " @@ -1950,6 +1981,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): refer_union: Mapped[UnionType] refer_union_optional: Mapped[Optional[UnionType]] + # py38, 37 does not automatically flatten unions, add extra tests + # for this. maintain these in order to catch future regressions + # in the behavior of ``Union`` + unflat_union_optional_data: Mapped[ + Union[Union[Decimal, float, None], None] + ] = mapped_column() + float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1973,6 +2011,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ("reverse_u_optional_data", True), ("refer_union", "null" in union.name), ("refer_union_optional", True), + ("unflat_union_optional_data", True), ] if compat.py310: info += [ @@ -2039,36 +2078,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(A.__table__.c.json1.nullable) is_false(A.__table__.c.json2.nullable) - @testing.combinations( - ("not_optional",), - ("optional",), - ("optional_fwd_ref",), - ("union_none",), - ("pep604", testing.requires.python310), - ("pep604_fwd_ref", testing.requires.python310), - argnames="optional_on_json", + @testing.variation( + "option", + [ + "not_optional", + "optional", + "optional_fwd_ref", + "union_none", + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + ], ) + @testing.variation("brackets", ["oneset", "twosets"]) @testing.combinations( "include_mc_type", "derive_from_anno", argnames="include_mc_type" ) def test_optional_styles_nested_brackets( - self, optional_on_json, include_mc_type + self, option, brackets, include_mc_type ): + """composed types test, includes tests that were added later for + #12207""" + class Base(DeclarativeBase): if testing.requires.python310.enabled: type_annotation_map = { - Dict[str, str]: JSON, - dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, + list[int] | list[str]: JSON, } else: type_annotation_map = { - Dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, } if include_mc_type == "include_mc_type": mc = mapped_column(JSON) + mc2 = mapped_column(JSON) else: mc = mapped_column() + mc2 = mapped_column() class A(Base): __tablename__ = "a" @@ -2076,21 +2126,67 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[str] = mapped_column() - if optional_on_json == "not_optional": - json: Mapped[Dict[str, str]] = mapped_column() # type: ignore - elif optional_on_json == "optional": - json: Mapped[Optional[Dict[str, str]]] = mc - elif optional_on_json == "optional_fwd_ref": - json: Mapped["Optional[Dict[str, str]]"] = mc - elif optional_on_json == "union_none": - json: Mapped[Union[Dict[str, str], None]] = mc - elif optional_on_json == "pep604": - json: Mapped[dict[str, str] | None] = mc - elif optional_on_json == "pep604_fwd_ref": - json: Mapped["dict[str, str] | None"] = mc + if brackets.oneset: + if option.not_optional: + json: Mapped[Dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + if testing.requires.python310.enabled: + json2: Mapped[dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Dict[str, Decimal]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Optional[dict[str, Decimal]]] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Dict[str, Decimal]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped["Optional[dict[str, Decimal]]"] = mc2 + elif option.union_none: + json: Mapped[Union[Dict[str, Decimal], None]] = mc + json2: Mapped[Union[None, Dict[str, Decimal]]] = mc2 + elif option.pep604: + json: Mapped[dict[str, Decimal] | None] = mc + if testing.requires.python310.enabled: + json2: Mapped[None | dict[str, Decimal]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["dict[str, Decimal] | None"] = mc + if testing.requires.python310.enabled: + json2: Mapped["None | dict[str, Decimal]"] = mc2 + elif brackets.twosets: + if option.not_optional: + json: Mapped[Union[List[int], List[str]]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Union[List[int], List[str]]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + Optional[Union[list[int], list[str]]] + ] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Union[List[int], List[str]]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + "Optional[Union[list[int], list[str]]]" + ] = mc2 + elif option.union_none: + json: Mapped[Union[List[int], List[str], None]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Union[None, list[int], list[str]]] = mc2 + elif option.pep604: + json: Mapped[list[int] | list[str] | None] = mc + json2: Mapped[None | list[int] | list[str]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["list[int] | list[str] | None"] = mc + json2: Mapped["None | list[int] | list[str]"] = mc2 + else: + brackets.fail() is_(A.__table__.c.json.type._type_affinity, JSON) - if optional_on_json == "not_optional": + if hasattr(A, "json2"): + is_(A.__table__.c.json2.type._type_affinity, JSON) + if option.not_optional: + is_false(A.__table__.c.json2.nullable) + else: + is_true(A.__table__.c.json2.nullable) + + if option.not_optional: is_false(A.__table__.c.json.nullable) else: is_true(A.__table__.c.json.nullable) @@ -3147,7 +3243,7 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "union_none": - a: Mapped["Union[A, None]"] = relationship( + a: Mapped[Union[A, None]] = relationship( back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "pep604": diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 558d646430..1a90eadd9d 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -107,8 +107,9 @@ _UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2] _StrTypeAlias: TypeAlias = str -_StrPep695: TypeAlias = str -_UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] +if TYPE_CHECKING: + _StrPep695: TypeAlias = str + _UnionPep695: TypeAlias = Union[_SomeDict1, _SomeDict2] _TypingLiteral = typing.Literal["a", "b"] _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"] @@ -148,6 +149,17 @@ type _JsonPep695 = _JsonPep604 ) +def make_pep695_type(name, definition): + lcls = {} + exec( + f""" +type {name} = {definition} +""", + lcls, + ) + return lcls[name] + + def expect_annotation_syntax_error(name): return expect_raises_message( sa_exc.ArgumentError, @@ -853,6 +865,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): "optional", "optional_union", "optional_union_604", + "union_newtype", + "union_null_newtype", + "union_695", + "union_null_695", ], ) @testing.variation("in_map", ["yes", "no", "value"]) @@ -877,12 +893,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): tat = TypeAliasType("tat", Optional[Union[str, int]]) elif option.optional_union_604: tat = TypeAliasType("tat", Optional[str | int]) + elif option.union_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int]) + elif option.union_null_newtype: + # this seems to be illegal for typing but "works" + tat = NewType("tat", Union[str, int, None]) + elif option.union_695: + tat = make_pep695_type("tat", str | int) + elif option.union_null_695: + tat = make_pep695_type("tat", str | int | None) else: option.fail() if in_map.yes: decl_base.registry.update_type_annotation_map({tat: String(99)}) - elif in_map.value: + elif in_map.value and "newtype" not in option.name: decl_base.registry.update_type_annotation_map( {tat.__value__: String(99)} ) @@ -898,7 +924,12 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if in_map.yes: col = declare() length = 99 - elif in_map.value or option.optional or option.plain: + elif ( + in_map.value + and "newtype" not in option.name + or option.optional + or option.plain + ): with expect_deprecated( "Matching the provided TypeAliasType 'tat' on its " "resolved value without matching it in the " @@ -1941,6 +1972,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): refer_union: Mapped[UnionType] refer_union_optional: Mapped[Optional[UnionType]] + # py38, 37 does not automatically flatten unions, add extra tests + # for this. maintain these in order to catch future regressions + # in the behavior of ``Union`` + unflat_union_optional_data: Mapped[ + Union[Union[Decimal, float, None], None] + ] = mapped_column() + float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1964,6 +2002,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ("reverse_u_optional_data", True), ("refer_union", "null" in union.name), ("refer_union_optional", True), + ("unflat_union_optional_data", True), ] if compat.py310: info += [ @@ -2030,36 +2069,47 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(A.__table__.c.json1.nullable) is_false(A.__table__.c.json2.nullable) - @testing.combinations( - ("not_optional",), - ("optional",), - ("optional_fwd_ref",), - ("union_none",), - ("pep604", testing.requires.python310), - ("pep604_fwd_ref", testing.requires.python310), - argnames="optional_on_json", + @testing.variation( + "option", + [ + "not_optional", + "optional", + "optional_fwd_ref", + "union_none", + ("pep604", testing.requires.python310), + ("pep604_fwd_ref", testing.requires.python310), + ], ) + @testing.variation("brackets", ["oneset", "twosets"]) @testing.combinations( "include_mc_type", "derive_from_anno", argnames="include_mc_type" ) def test_optional_styles_nested_brackets( - self, optional_on_json, include_mc_type + self, option, brackets, include_mc_type ): + """composed types test, includes tests that were added later for + #12207""" + class Base(DeclarativeBase): if testing.requires.python310.enabled: type_annotation_map = { - Dict[str, str]: JSON, - dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, + list[int] | list[str]: JSON, } else: type_annotation_map = { - Dict[str, str]: JSON, + Dict[str, Decimal]: JSON, + Union[List[int], List[str]]: JSON, } if include_mc_type == "include_mc_type": mc = mapped_column(JSON) + mc2 = mapped_column(JSON) else: mc = mapped_column() + mc2 = mapped_column() class A(Base): __tablename__ = "a" @@ -2067,21 +2117,67 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[str] = mapped_column() - if optional_on_json == "not_optional": - json: Mapped[Dict[str, str]] = mapped_column() # type: ignore - elif optional_on_json == "optional": - json: Mapped[Optional[Dict[str, str]]] = mc - elif optional_on_json == "optional_fwd_ref": - json: Mapped["Optional[Dict[str, str]]"] = mc - elif optional_on_json == "union_none": - json: Mapped[Union[Dict[str, str], None]] = mc - elif optional_on_json == "pep604": - json: Mapped[dict[str, str] | None] = mc - elif optional_on_json == "pep604_fwd_ref": - json: Mapped["dict[str, str] | None"] = mc + if brackets.oneset: + if option.not_optional: + json: Mapped[Dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + if testing.requires.python310.enabled: + json2: Mapped[dict[str, Decimal]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Dict[str, Decimal]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Optional[dict[str, Decimal]]] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Dict[str, Decimal]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped["Optional[dict[str, Decimal]]"] = mc2 + elif option.union_none: + json: Mapped[Union[Dict[str, Decimal], None]] = mc + json2: Mapped[Union[None, Dict[str, Decimal]]] = mc2 + elif option.pep604: + json: Mapped[dict[str, Decimal] | None] = mc + if testing.requires.python310.enabled: + json2: Mapped[None | dict[str, Decimal]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["dict[str, Decimal] | None"] = mc + if testing.requires.python310.enabled: + json2: Mapped["None | dict[str, Decimal]"] = mc2 + elif brackets.twosets: + if option.not_optional: + json: Mapped[Union[List[int], List[str]]] = mapped_column() # type: ignore # noqa: E501 + elif option.optional: + json: Mapped[Optional[Union[List[int], List[str]]]] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + Optional[Union[list[int], list[str]]] + ] = mc2 + elif option.optional_fwd_ref: + json: Mapped["Optional[Union[List[int], List[str]]]"] = mc + if testing.requires.python310.enabled: + json2: Mapped[ + "Optional[Union[list[int], list[str]]]" + ] = mc2 + elif option.union_none: + json: Mapped[Union[List[int], List[str], None]] = mc + if testing.requires.python310.enabled: + json2: Mapped[Union[None, list[int], list[str]]] = mc2 + elif option.pep604: + json: Mapped[list[int] | list[str] | None] = mc + json2: Mapped[None | list[int] | list[str]] = mc2 + elif option.pep604_fwd_ref: + json: Mapped["list[int] | list[str] | None"] = mc + json2: Mapped["None | list[int] | list[str]"] = mc2 + else: + brackets.fail() is_(A.__table__.c.json.type._type_affinity, JSON) - if optional_on_json == "not_optional": + if hasattr(A, "json2"): + is_(A.__table__.c.json2.type._type_affinity, JSON) + if option.not_optional: + is_false(A.__table__.c.json2.nullable) + else: + is_true(A.__table__.c.json2.nullable) + + if option.not_optional: is_false(A.__table__.c.json.nullable) else: is_true(A.__table__.c.json.nullable) @@ -3138,7 +3234,7 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "union_none": - a: Mapped["Union[A, None]"] = relationship( + a: Mapped[Union[A, None]] = relationship( back_populates="bs", primaryjoin=a_id == A.id ) elif optional_on_m2o == "pep604":