--- /dev/null
+.. change::
+ :tags: bug, orm
+ :tickets: 11370
+
+ Fixed issue regarding ``Union`` types that would be present in the
+ :paramref:`_orm.registry.type_annotation_map` of a :class:`_orm.registry`
+ or declarative base class, where a ``Mapped[]`` element that included one
+ of the subtypes present in that ``Union`` would be matched to that entry,
+ potentially ignoring other entries that matched exactly. The correct
+ behavior now takes place such that an entry should only match in
+ ``type_annotation_map`` exactly, as a ``Union`` type is a self-contained
+ type. For example, an attribute with ``Mapped[float]`` would previously
+ match to a ``type_annotation_map`` entry ``Union[float, Decimal]``; this
+ will no longer match and will now only match to an entry that states
+ ``float``. Pull request courtesy Frazer McLean.
from ..util import hybridproperty
from ..util import typing as compat_typing
from ..util.typing import CallableReference
+from ..util.typing import de_optionalize_union_types
from ..util.typing import flatten_newtype
from ..util.typing import is_generic
from ..util.typing import is_literal
self.type_annotation_map.update(
{
- sub_type: sqltype
+ de_optionalize_union_types(typ): sqltype
for typ, sqltype in type_annotation_map.items()
- for sub_type in compat_typing.expand_unions(
- typ, include_union=True, discard_none=True
- )
}
)
"""
+ while is_pep695(type_):
+ type_ = type_.__value__
+
if is_fwd_ref(type_):
return de_optionalize_fwd_ref_union_types(type_)
return cast(Any, Union).__getitem__(types) # type: ignore
-def expand_unions(
- type_: Type[Any], include_union: bool = False, discard_none: bool = False
-) -> Tuple[Type[Any], ...]:
- """Return a type as a tuple of individual types, expanding for
- ``Union`` types."""
-
- if is_union(type_):
- typ = set(type_.__args__)
-
- if discard_none:
- typ.discard(NoneType)
-
- if include_union:
- return (type_,) + tuple(typ) # type: ignore
- else:
- return tuple(typ) # type: ignore
- else:
- return (type_,)
-
-
def is_optional(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
return is_origin_of(
type_,
def is_union(type_: Any) -> TypeGuard[ArgsTypeProcotol]:
- return is_origin_of(type_, "Union")
+ return is_origin_of(type_, "Union", "UnionType")
def is_origin_of_cls(
from pathlib import Path
import pickle
import sys
+import typing
+
+import typing_extensions
from sqlalchemy import exc
from sqlalchemy import sql
from sqlalchemy.util._collections import merge_lists_w_ordering
from sqlalchemy.util._has_cy import _import_cy_extensions
from sqlalchemy.util._has_cy import HAS_CYEXTENSION
+from sqlalchemy.util.typing import is_union
class WeakSequenceTest(fixtures.TestBase):
for f in cython_files
}
eq_({m.__name__ for m in ext}, set(names))
+
+
+class TypingTest(fixtures.TestBase):
+ def test_is_union(self):
+ assert is_union(typing.Union[str, int])
+ assert is_union(typing_extensions.Union[str, int])
+ if compat.py310:
+ assert is_union(str | int)
from typing_extensions import get_args as get_args
from typing_extensions import Literal as Literal
from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypeAliasType
from typing_extensions import TypedDict
from sqlalchemy import BIGINT
from sqlalchemy import DateTime
from sqlalchemy import exc
from sqlalchemy import exc as sa_exc
+from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
+from sqlalchemy.testing import skip_test
from sqlalchemy.testing import Variation
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.util import compat
_TypingLiteral = typing.Literal["a", "b"]
_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None]
+_JsonObject: TypeAlias = Dict[str, "_Json"]
+_JsonArray: TypeAlias = List["_Json"]
+_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive]
+
+if compat.py310:
+ _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None
+ _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"]
+ _JsonArrayPep604: TypeAlias = list["_JsonPep604"]
+ _JsonPep604: TypeAlias = (
+ _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604
+ )
+
if compat.py312:
exec(
"""
else:
is_(getattr(Element.__table__.c.data, paramname), override_value)
- def test_unions(self):
+ @testing.variation("union", ["union", "pep604"])
+ @testing.variation("typealias", ["legacy", "pep695"])
+ def test_unions(self, union, typealias):
our_type = Numeric(10, 2)
+ if union.union:
+ UnionType = Union[float, Decimal]
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Required Python 3.10")
+ UnionType = float | Decimal
+ else:
+ union.fail()
+
+ if typealias.legacy:
+ UnionTypeAlias = UnionType
+ elif typealias.pep695:
+ # same as type UnionTypeAlias = UnionType
+ UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
+ else:
+ typealias.fail()
+
class Base(DeclarativeBase):
- type_annotation_map = {Union[float, Decimal]: our_type}
+ type_annotation_map = {UnionTypeAlias: our_type}
class User(Base):
__tablename__ = "users"
mapped_column()
)
+ if compat.py312:
+ MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
+ pep695_data: Mapped[MyTypeAlias] = mapped_column()
+
is_(User.__table__.c.data.type, our_type)
is_false(User.__table__.c.data.nullable)
is_(User.__table__.c.reverse_data.type, our_type)
is_true(User.__table__.c.reverse_optional_data.nullable)
is_true(User.__table__.c.reverse_u_optional_data.nullable)
- is_(User.__table__.c.float_data.type, our_type)
- is_(User.__table__.c.decimal_data.type, our_type)
+ is_true(isinstance(User.__table__.c.float_data.type, Float))
+ is_true(isinstance(User.__table__.c.float_data.type, Numeric))
+ is_not(User.__table__.c.decimal_data.type, our_type)
if compat.py310:
for suffix in ("", "_fwd"):
is_(optional_col.type, our_type)
is_true(optional_col.nullable)
+ if compat.py312:
+ is_(User.__table__.c.pep695_data.type, our_type)
+
+ @testing.variation("union", ["union", "pep604"])
+ def test_optional_in_annotation_map(self, union):
+ """SQLAlchemy's behaviour is clear: an optional type means the column
+ is inferred as nullable. Some types which a user may want to put in the
+ type annotation map are already optional. JSON is a good example
+ because without any constraint, the type can be None via JSON null or
+ SQL NULL.
+
+ By permitting optional types in the type annotation map, everything
+ just works, and mapped_column(nullable=False) is available if desired.
+
+ See issue #11370
+ """
+
+ class Base(DeclarativeBase):
+ if union.union:
+ type_annotation_map = {
+ _Json: JSON,
+ }
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Requires Python 3.10+")
+ type_annotation_map = {
+ _JsonPep604: JSON,
+ }
+ else:
+ union.fail()
+
+ class A(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ if union.union:
+ json1: Mapped[_Json]
+ json2: Mapped[_Json] = mapped_column(nullable=False)
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Requires Python 3.10+")
+ json1: Mapped[_JsonPep604]
+ json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+ else:
+ union.fail()
+
+ is_(A.__table__.c.json1.type._type_affinity, JSON)
+ is_(A.__table__.c.json2.type._type_affinity, JSON)
+ is_true(A.__table__.c.json1.nullable)
+ is_false(A.__table__.c.json2.nullable)
+
@testing.combinations(
("not_optional",),
("optional",),
from typing_extensions import get_args as get_args
from typing_extensions import Literal as Literal
from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypeAliasType
from typing_extensions import TypedDict
from sqlalchemy import BIGINT
from sqlalchemy import DateTime
from sqlalchemy import exc
from sqlalchemy import exc as sa_exc
+from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Identity
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
+from sqlalchemy.testing import skip_test
from sqlalchemy.testing import Variation
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.util import compat
_TypingLiteral = typing.Literal["a", "b"]
_TypingExtensionsLiteral = typing_extensions.Literal["a", "b"]
+_JsonPrimitive: TypeAlias = Union[str, int, float, bool, None]
+_JsonObject: TypeAlias = Dict[str, "_Json"]
+_JsonArray: TypeAlias = List["_Json"]
+_Json: TypeAlias = Union[_JsonObject, _JsonArray, _JsonPrimitive]
+
+if compat.py310:
+ _JsonPrimitivePep604: TypeAlias = str | int | float | bool | None
+ _JsonObjectPep604: TypeAlias = dict[str, "_JsonPep604"]
+ _JsonArrayPep604: TypeAlias = list["_JsonPep604"]
+ _JsonPep604: TypeAlias = (
+ _JsonObjectPep604 | _JsonArrayPep604 | _JsonPrimitivePep604
+ )
+
if compat.py312:
exec(
"""
else:
is_(getattr(Element.__table__.c.data, paramname), override_value)
- def test_unions(self):
+ @testing.variation("union", ["union", "pep604"])
+ @testing.variation("typealias", ["legacy", "pep695"])
+ def test_unions(self, union, typealias):
our_type = Numeric(10, 2)
+ if union.union:
+ UnionType = Union[float, Decimal]
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Required Python 3.10")
+ UnionType = float | Decimal
+ else:
+ union.fail()
+
+ if typealias.legacy:
+ UnionTypeAlias = UnionType
+ elif typealias.pep695:
+ # same as type UnionTypeAlias = UnionType
+ UnionTypeAlias = TypeAliasType("UnionTypeAlias", UnionType)
+ else:
+ typealias.fail()
+
class Base(DeclarativeBase):
- type_annotation_map = {Union[float, Decimal]: our_type}
+ type_annotation_map = {UnionTypeAlias: our_type}
class User(Base):
__tablename__ = "users"
mapped_column()
)
+ if compat.py312:
+ MyTypeAlias = TypeAliasType("MyTypeAlias", float | Decimal)
+ pep695_data: Mapped[MyTypeAlias] = mapped_column()
+
is_(User.__table__.c.data.type, our_type)
is_false(User.__table__.c.data.nullable)
is_(User.__table__.c.reverse_data.type, our_type)
is_true(User.__table__.c.reverse_optional_data.nullable)
is_true(User.__table__.c.reverse_u_optional_data.nullable)
- is_(User.__table__.c.float_data.type, our_type)
- is_(User.__table__.c.decimal_data.type, our_type)
+ is_true(isinstance(User.__table__.c.float_data.type, Float))
+ is_true(isinstance(User.__table__.c.float_data.type, Numeric))
+ is_not(User.__table__.c.decimal_data.type, our_type)
if compat.py310:
for suffix in ("", "_fwd"):
is_(optional_col.type, our_type)
is_true(optional_col.nullable)
+ if compat.py312:
+ is_(User.__table__.c.pep695_data.type, our_type)
+
+ @testing.variation("union", ["union", "pep604"])
+ def test_optional_in_annotation_map(self, union):
+ """SQLAlchemy's behaviour is clear: an optional type means the column
+ is inferred as nullable. Some types which a user may want to put in the
+ type annotation map are already optional. JSON is a good example
+ because without any constraint, the type can be None via JSON null or
+ SQL NULL.
+
+ By permitting optional types in the type annotation map, everything
+ just works, and mapped_column(nullable=False) is available if desired.
+
+ See issue #11370
+ """
+
+ class Base(DeclarativeBase):
+ if union.union:
+ type_annotation_map = {
+ _Json: JSON,
+ }
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Requires Python 3.10+")
+ type_annotation_map = {
+ _JsonPep604: JSON,
+ }
+ else:
+ union.fail()
+
+ class A(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ if union.union:
+ json1: Mapped[_Json]
+ json2: Mapped[_Json] = mapped_column(nullable=False)
+ elif union.pep604:
+ if not compat.py310:
+ skip_test("Requires Python 3.10+")
+ json1: Mapped[_JsonPep604]
+ json2: Mapped[_JsonPep604] = mapped_column(nullable=False)
+ else:
+ union.fail()
+
+ is_(A.__table__.c.json1.type._type_affinity, JSON)
+ is_(A.__table__.c.json2.type._type_affinity, JSON)
+ is_true(A.__table__.c.json1.nullable)
+ is_false(A.__table__.c.json2.nullable)
+
@testing.combinations(
("not_optional",),
("optional",),