]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dont match partial types in type_annotation_map
authorFrazer McLean <frazer@frazermclean.co.uk>
Thu, 3 Oct 2024 22:21:12 +0000 (18:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 10 Dec 2024 23:33:53 +0000 (18:33 -0500)
Fixed issue regarding ``Union`` types that would be present in the
:paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry`
or declarative base class, where a ``Mapped[]`` element that included one
of the subtypes present in that ``Union`` would be matched to that entry,
potentially ignoring other entries that matched exactly.   The correct
behavior now takes place such that an entry should only match in
``type_annotation_map`` exactly, as a ``Union`` type is a self-contained
type. For example, an attribute with ``Mapped[float]`` would previously
match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this
will no longer match and will now only match to an entry that states
``float``. Pull request courtesy Frazer McLean.

Fixes #11370
Closes: #11942
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/11942
Pull-request-sha: 21a3d1971a04e117a557f6e6bac77bce9f6bb0a9

Change-Id: I3467be00f8fa8bd011dd4805a77a3b80ff74a215
(cherry picked from commit 40c30ec44616223216737327f97bac66a13eedee)

doc/build/changelog/unreleased_20/11370.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/util/typing.py
test/base/test_utils.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/11370.rst b/doc/build/changelog/unreleased_20/11370.rst
new file mode 100644 (file)
index 0000000..56e8553
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 11370
+
+    Fixed issue regarding ``Union`` types that would be present in the
+    :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry`
+    or declarative base class, where a ``Mapped[]`` element that included one
+    of the subtypes present in that ``Union`` would be matched to that entry,
+    potentially ignoring other entries that matched exactly.   The correct
+    behavior now takes place such that an entry should only match in
+    ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained
+    type. For example, an attribute with ``Mapped[float]`` would previously
+    match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this
+    will no longer match and will now only match to an entry that states
+    ``float``. Pull request courtesy Frazer McLean.
index 311a9bd4a5121d4734eaaa2bccfa52508efe70dd..718cf72516bd6773f4e3b7455886c1aa9fa9977a 100644 (file)
@@ -73,6 +73,7 @@ from ..util import hybridmethod
 from ..util import hybridproperty
 from ..util import typing as compat_typing
 from ..util.typing import CallableReference
+from ..util.typing import de_optionalize_union_types
 from ..util.typing import flatten_newtype
 from ..util.typing import is_generic
 from ..util.typing import is_literal
@@ -1225,11 +1226,8 @@ class registry:
 
         self.type_annotation_map.update(
             {
-                sub_type: sqltype
+                de_optionalize_union_types(typ): sqltype
                 for typ, sqltype in type_annotation_map.items()
-                for sub_type in compat_typing.expand_unions(
-                    typ, include_union=True, discard_none=True
-                )
             }
         )
 
index a3df97770542af215626b414be16cfc43f76d98e..bd1ebd4c01380e5b5351f0a1a093e3154ea06da4 100644 (file)
@@ -422,6 +422,9 @@ def de_optionalize_union_types(
 
     """
 
+    while is_pep695(type_):
+        type_ = type_.__value__
+
     if is_fwd_ref(type_):
         return de_optionalize_fwd_ref_union_types(type_)
 
@@ -478,26 +481,6 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
     return cast(Any, Union).__getitem__(types)  # type: ignore
 
 
-def expand_unions(
-    type_: Type[Any], include_union: bool = False, discard_none: bool = False
-) -> Tuple[Type[Any], ...]:
-    """Return a type as a tuple of individual types, expanding for
-    ``Union`` types."""
-
-    if is_union(type_):
-        typ = set(type_.__args__)
-
-        if discard_none:
-            typ.discard(NoneType)
-
-        if include_union:
-            return (type_,) + tuple(typ)  # type: ignore
-        else:
-            return tuple(typ)  # type: ignore
-    else:
-        return (type_,)
-
-
 def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
     return is_origin_of(
         type_,
@@ -512,7 +495,7 @@ def is_optional_union(type_: Any) -> bool:
 
 
 def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
-    return is_origin_of(type_, "Union")
+    return is_origin_of(type_, "Union", "UnionType")
 
 
 def is_origin_of_cls(
index de8712c852343b54c356695ec7f0413f4ad7bcd3..85c419e94e81f18eac042a7219ba901ca36e2c1c 100644 (file)
@@ -4,6 +4,9 @@ import inspect
 from pathlib import Path
 import pickle
 import sys
+import typing
+
+import typing_extensions
 
 from sqlalchemy import exc
 from sqlalchemy import sql
@@ -39,6 +42,7 @@ from sqlalchemy.util import WeakSequence
 from sqlalchemy.util._collections import merge_lists_w_ordering
 from sqlalchemy.util._has_cy import _import_cy_extensions
 from sqlalchemy.util._has_cy import HAS_CYEXTENSION
+from sqlalchemy.util.typing import is_union
 
 
 class WeakSequenceTest(fixtures.TestBase):
@@ -3630,3 +3634,11 @@ class CyExtensionTest(fixtures.TestBase):
             for f in cython_files
         }
         eq_({m.__name__ for m in ext}, set(names))
+
+
+class TypingTest(fixtures.TestBase):
+    def test_is_union(self):
+        assert is_union(typing.Union[str, int])
+        assert is_union(typing_extensions.Union[str, int])
+        if compat.py310:
+            assert is_union(str | int)
index 2aa8f0f0b0f0944771957f1dfcf96945bec3502c..6bf7d02c56c2744651e68a165c66b59fc7a2fb69 100644 (file)
@@ -33,6 +33,7 @@ import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypeAliasType
 from typing_extensions import TypedDict
 
 from sqlalchemy import BIGINT
@@ -41,6 +42,7 @@ from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import exc
 from sqlalchemy import exc as sa_exc
+from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Identity
@@ -94,6 +96,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing import skip_test
 from sqlalchemy.testing import Variation
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
@@ -124,6 +127,19 @@ if compat.py38:
     _TypingLiteral = typing.Literal["a", "b"]
 _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
 
+_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None]
+_JsonObject: TypeAlias = Dict[str, "_Json"]
+_JsonArray: TypeAlias = List["_Json"]
+_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive]
+
+if compat.py310:
+    _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None
+    _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"]
+    _JsonArrayPep604: TypeAlias = list["_JsonPep604"]
+    _JsonPep604: TypeAlias = (
+        _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604
+    )
+
 if compat.py312:
     exec(
         """
@@ -1708,11 +1724,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
-    def test_unions(self):
+    @testing.variation("union", ["union", "pep604"])
+    @testing.variation("typealias", ["legacy", "pep695"])
+    def test_unions(self, union, typealias):
         our_type = Numeric(10, 2)
 
+        if union.union:
+            UnionType = Union[float, Decimal]
+        elif union.pep604:
+            if not compat.py310:
+                skip_test("Required Python 3.10")
+            UnionType = float | Decimal
+        else:
+            union.fail()
+
+        if typealias.legacy:
+            UnionTypeAlias = UnionType
+        elif typealias.pep695:
+            # same as type UnionTypeAlias = UnionType
+            UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
+        else:
+            typealias.fail()
+
         class Base(DeclarativeBase):
-            type_annotation_map = {Union[float, Decimal]: our_type}
+            type_annotation_map = {UnionTypeAlias: our_type}
 
         class User(Base):
             __tablename__ = "users"
@@ -1753,6 +1788,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column()
                 )
 
+            if compat.py312:
+                MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
+                pep695_data: Mapped[MyTypeAlias] = mapped_column()
+
         is_(User.__table__.c.data.type, our_type)
         is_false(User.__table__.c.data.nullable)
         is_(User.__table__.c.reverse_data.type, our_type)
@@ -1764,8 +1803,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(User.__table__.c.reverse_optional_data.nullable)
         is_true(User.__table__.c.reverse_u_optional_data.nullable)
 
-        is_(User.__table__.c.float_data.type, our_type)
-        is_(User.__table__.c.decimal_data.type, our_type)
+        is_true(isinstance(User.__table__.c.float_data.type, Float))
+        is_true(isinstance(User.__table__.c.float_data.type, Numeric))
+        is_not(User.__table__.c.decimal_data.type, our_type)
 
         if compat.py310:
             for suffix in ("", "_fwd"):
@@ -1779,6 +1819,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 is_(optional_col.type, our_type)
                 is_true(optional_col.nullable)
 
+        if compat.py312:
+            is_(User.__table__.c.pep695_data.type, our_type)
+
+    @testing.variation("union", ["union", "pep604"])
+    def test_optional_in_annotation_map(self, union):
+        """SQLAlchemy's behaviour is clear: an optional type means the column
+        is inferred as nullable. Some types which a user may want to put in the
+        type annotation map are already optional. JSON is a good example
+        because without any constraint, the type can be None via JSON null or
+        SQL NULL.
+
+        By permitting optional types in the type annotation map, everything
+        just works, and mapped_column(nullable=False) is available if desired.
+
+        See issue #11370
+        """
+
+        class Base(DeclarativeBase):
+            if union.union:
+                type_annotation_map = {
+                    _Json: JSON,
+                }
+            elif union.pep604:
+                if not compat.py310:
+                    skip_test("Requires Python 3.10+")
+                type_annotation_map = {
+                    _JsonPep604: JSON,
+                }
+            else:
+                union.fail()
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            if union.union:
+                json1: Mapped[_Json]
+                json2: Mapped[_Json] = mapped_column(nullable=False)
+            elif union.pep604:
+                if not compat.py310:
+                    skip_test("Requires Python 3.10+")
+                json1: Mapped[_JsonPep604]
+                json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+            else:
+                union.fail()
+
+        is_(A.__table__.c.json1.type._type_affinity, JSON)
+        is_(A.__table__.c.json2.type._type_affinity, JSON)
+        is_true(A.__table__.c.json1.nullable)
+        is_false(A.__table__.c.json2.nullable)
+
     @testing.combinations(
         ("not_optional",),
         ("optional",),
index b50573fa12fb4b6b924065060b21b62dd91b2fe2..929041ccfbff2f4e943df10cd78414002080c0cf 100644 (file)
@@ -24,6 +24,7 @@ import typing_extensions
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
 from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypeAliasType
 from typing_extensions import TypedDict
 
 from sqlalchemy import BIGINT
@@ -32,6 +33,7 @@ from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import exc
 from sqlalchemy import exc as sa_exc
+from sqlalchemy import Float
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Identity
@@ -85,6 +87,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
+from sqlalchemy.testing import skip_test
 from sqlalchemy.testing import Variation
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.util import compat
@@ -115,6 +118,19 @@ if compat.py38:
     _TypingLiteral = typing.Literal["a", "b"]
 _TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
 
+_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None]
+_JsonObject: TypeAlias = Dict[str, "_Json"]
+_JsonArray: TypeAlias = List["_Json"]
+_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive]
+
+if compat.py310:
+    _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None
+    _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"]
+    _JsonArrayPep604: TypeAlias = list["_JsonPep604"]
+    _JsonPep604: TypeAlias = (
+        _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604
+    )
+
 if compat.py312:
     exec(
         """
@@ -1699,11 +1715,30 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
-    def test_unions(self):
+    @testing.variation("union", ["union", "pep604"])
+    @testing.variation("typealias", ["legacy", "pep695"])
+    def test_unions(self, union, typealias):
         our_type = Numeric(10, 2)
 
+        if union.union:
+            UnionType = Union[float, Decimal]
+        elif union.pep604:
+            if not compat.py310:
+                skip_test("Required Python 3.10")
+            UnionType = float | Decimal
+        else:
+            union.fail()
+
+        if typealias.legacy:
+            UnionTypeAlias = UnionType
+        elif typealias.pep695:
+            # same as type UnionTypeAlias = UnionType
+            UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
+        else:
+            typealias.fail()
+
         class Base(DeclarativeBase):
-            type_annotation_map = {Union[float, Decimal]: our_type}
+            type_annotation_map = {UnionTypeAlias: our_type}
 
         class User(Base):
             __tablename__ = "users"
@@ -1744,6 +1779,10 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column()
                 )
 
+            if compat.py312:
+                MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
+                pep695_data: Mapped[MyTypeAlias] = mapped_column()
+
         is_(User.__table__.c.data.type, our_type)
         is_false(User.__table__.c.data.nullable)
         is_(User.__table__.c.reverse_data.type, our_type)
@@ -1755,8 +1794,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(User.__table__.c.reverse_optional_data.nullable)
         is_true(User.__table__.c.reverse_u_optional_data.nullable)
 
-        is_(User.__table__.c.float_data.type, our_type)
-        is_(User.__table__.c.decimal_data.type, our_type)
+        is_true(isinstance(User.__table__.c.float_data.type, Float))
+        is_true(isinstance(User.__table__.c.float_data.type, Numeric))
+        is_not(User.__table__.c.decimal_data.type, our_type)
 
         if compat.py310:
             for suffix in ("", "_fwd"):
@@ -1770,6 +1810,57 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 is_(optional_col.type, our_type)
                 is_true(optional_col.nullable)
 
+        if compat.py312:
+            is_(User.__table__.c.pep695_data.type, our_type)
+
+    @testing.variation("union", ["union", "pep604"])
+    def test_optional_in_annotation_map(self, union):
+        """SQLAlchemy's behaviour is clear: an optional type means the column
+        is inferred as nullable. Some types which a user may want to put in the
+        type annotation map are already optional. JSON is a good example
+        because without any constraint, the type can be None via JSON null or
+        SQL NULL.
+
+        By permitting optional types in the type annotation map, everything
+        just works, and mapped_column(nullable=False) is available if desired.
+
+        See issue #11370
+        """
+
+        class Base(DeclarativeBase):
+            if union.union:
+                type_annotation_map = {
+                    _Json: JSON,
+                }
+            elif union.pep604:
+                if not compat.py310:
+                    skip_test("Requires Python 3.10+")
+                type_annotation_map = {
+                    _JsonPep604: JSON,
+                }
+            else:
+                union.fail()
+
+        class A(Base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            if union.union:
+                json1: Mapped[_Json]
+                json2: Mapped[_Json] = mapped_column(nullable=False)
+            elif union.pep604:
+                if not compat.py310:
+                    skip_test("Requires Python 3.10+")
+                json1: Mapped[_JsonPep604]
+                json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+            else:
+                union.fail()
+
+        is_(A.__table__.c.json1.type._type_affinity, JSON)
+        is_(A.__table__.c.json2.type._type_affinity, JSON)
+        is_true(A.__table__.c.json1.nullable)
+        is_false(A.__table__.c.json2.nullable)
+
     @testing.combinations(
         ("not_optional",),
         ("optional",),