from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
+from ..util.typing import de_stringify_union_elements
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
+from ..util.typing import is_union
from ..util.typing import Self
from ..util.typing import typing_get_args
if is_fwd_ref(argument):
argument = de_stringify_annotation(cls, argument)
+ if is_union(argument):
+ argument = de_stringify_union_elements(cls, argument)
+
nullable = is_optional_union(argument)
if not self._has_nullable:
checks = (our_type,)
for check_type in checks:
+
if registry.type_annotation_map:
new_sqltype = registry.type_annotation_map.get(check_type)
if new_sqltype is None:
return annotation # type: ignore
+def de_stringify_union_elements(
+ cls: Type[Any],
+ annotation: _AnnotationScanType,
+ str_cleanup_fn: Optional[Callable[[str], str]] = None,
+) -> Type[Any]:
+ return make_union_type(
+ *[
+ de_stringify_annotation(cls, anno, str_cleanup_fn)
+ for anno in annotation.__args__ # type: ignore
+ ]
+ )
+
+
def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
return type_ is not None and typing_get_origin(type_) is Annotated
return (type_,)
-def is_optional(type_):
+def is_optional(type_: Any) -> bool:
return is_origin_of(
type_,
"Optional",
return is_optional(type_) and NoneType in typing_get_args(type_)
-def is_union(type_):
+def is_union(type_: Any) -> bool:
return is_origin_of(type_, "Union")
from __future__ import annotations
+from decimal import Decimal
from typing import List
+from typing import Optional
from typing import Set
from typing import TypeVar
+from typing import Union
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import Numeric
+from sqlalchemy import Table
from sqlalchemy.orm import attribute_mapped_collection
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedCollection
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
-from .test_typed_mapping import MappedColumnTest # noqa
+from sqlalchemy.util import compat
+from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
"""runs the annotation-sensitive tests from test_typed_mappings while
_R = TypeVar("_R")
+class MappedColumnTest(_MappedColumnTest):
+ def test_unions(self):
+ our_type = Numeric(10, 2)
+
+ class Base(DeclarativeBase):
+ type_annotation_map = {Union[float, Decimal]: our_type}
+
+ class User(Base):
+ __tablename__ = "users"
+ __table__: Table
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ data: Mapped[Union[float, Decimal]] = mapped_column()
+ reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
+ optional_data: Mapped[
+ Optional[Union[float, Decimal]]
+ ] = mapped_column()
+
+ # use Optional directly
+ reverse_optional_data: Mapped[
+ Optional[Union[Decimal, float]]
+ ] = mapped_column()
+
+ # use Union with None, same as Optional but presents differently
+ # (Optional object with __origin__ Union vs. Union)
+ reverse_u_optional_data: Mapped[
+ Union[Decimal, float, None]
+ ] = mapped_column()
+
+ float_data: Mapped[float] = mapped_column()
+ decimal_data: Mapped[Decimal] = mapped_column()
+
+ if compat.py310:
+ pep604_data: Mapped[float | Decimal] = mapped_column()
+ pep604_reverse: Mapped[Decimal | float] = mapped_column()
+ pep604_optional: Mapped[
+ Decimal | float | None
+ ] = mapped_column()
+ pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+ pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+ pep604_optional_fwd: Mapped[
+ "Decimal | float | None"
+ ] = mapped_column()
+
+ is_(User.__table__.c.data.type, our_type)
+ is_false(User.__table__.c.data.nullable)
+ is_(User.__table__.c.reverse_data.type, our_type)
+ is_(User.__table__.c.optional_data.type, our_type)
+ is_true(User.__table__.c.optional_data.nullable)
+
+ is_(User.__table__.c.reverse_optional_data.type, our_type)
+ is_(User.__table__.c.reverse_u_optional_data.type, our_type)
+ is_true(User.__table__.c.reverse_optional_data.nullable)
+ is_true(User.__table__.c.reverse_u_optional_data.nullable)
+
+ is_(User.__table__.c.float_data.type, our_type)
+ is_(User.__table__.c.decimal_data.type, our_type)
+
+ if compat.py310:
+ for suffix in ("", "_fwd"):
+ data_col = User.__table__.c[f"pep604_data{suffix}"]
+ reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+ optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+ is_(data_col.type, our_type)
+ is_false(data_col.nullable)
+ is_(reverse_col.type, our_type)
+ is_false(reverse_col.nullable)
+ is_(optional_col.type, our_type)
+ is_true(optional_col.nullable)
+
+
class MappedOneArg(MappedCollection[str, _R]):
pass