from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
+from ..util.typing import is_optional_union
from ..util.typing import is_pep593
-from ..util.typing import NoneType
from ..util.typing import Self
from ..util.typing import typing_get_args
) -> None:
sqltype = self.column.type
- nullable = False
+ if is_fwd_ref(argument):
+ argument = de_stringify_annotation(cls, argument)
- if hasattr(argument, "__origin__"):
- nullable = NoneType in argument.__args__ # type: ignore
+ nullable = is_optional_union(argument)
if not self._has_nullable:
self.column.nullable = nullable
our_type = de_optionalize_union_types(argument)
- if is_fwd_ref(our_type):
- our_type = de_stringify_annotation(cls, our_type)
use_args_from = None
if is_pep593(our_type):
def expand_unions(
type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
- """Return a type as as a tuple of individual types, expanding for
+ """Return a type as a tuple of individual types, expanding for
``Union`` types."""
if is_union(type_):
type_,
"Optional",
"Union",
+ "UnionType",
)
+def is_optional_union(type_: Any) -> bool:
+ return is_optional(type_) and NoneType in typing_get_args(type_)
+
+
def is_union(type_):
return is_origin_of(type_, "Union")
"""return True if the given type has an __origin__ with the given name
and optional module."""
- origin = getattr(type_, "__origin__", None)
+ origin = typing_get_origin(type_)
if origin is None:
return False
from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.util import compat
from sqlalchemy.util.typing import Annotated
data: Mapped[Union[float, Decimal]] = mapped_column()
reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
+
optional_data: Mapped[
Optional[Union[float, Decimal]]
] = mapped_column()
reverse_u_optional_data: Mapped[
Union[Decimal, float, None]
] = mapped_column()
+
float_data: Mapped[float] = mapped_column()
decimal_data: Mapped[Decimal] = mapped_column()
+ if compat.py310:
+ pep604_data: Mapped[float | Decimal] = mapped_column()
+ pep604_reverse: Mapped[Decimal | float] = mapped_column()
+ pep604_optional: Mapped[
+ Decimal | float | None
+ ] = mapped_column()
+ pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
+ pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
+ pep604_optional_fwd: Mapped[
+ "Decimal | float | None"
+ ] = mapped_column()
+
is_(User.__table__.c.data.type, our_type)
is_false(User.__table__.c.data.nullable)
is_(User.__table__.c.reverse_data.type, our_type)
is_(User.__table__.c.float_data.type, our_type)
is_(User.__table__.c.decimal_data.type, our_type)
+ if compat.py310:
+ for suffix in ("", "_fwd"):
+ data_col = User.__table__.c[f"pep604_data{suffix}"]
+ reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
+ optional_col = User.__table__.c[f"pep604_optional{suffix}"]
+ is_(data_col.type, our_type)
+ is_false(data_col.nullable)
+ is_(reverse_col.type, our_type)
+ is_false(reverse_col.nullable)
+ is_(optional_col.type, our_type)
+ is_true(optional_col.nullable)
+
def test_missing_mapped_lhs(self, decl_base):
with expect_raises_message(
ArgumentError,