from ..sql.util import join_condition
from ..sql.util import selectables_overlap
from ..sql.util import visit_binary_product
+from ..util.typing import de_optionalize_union_types
from ..util.typing import Literal
if typing.TYPE_CHECKING:
self.lazy = "dynamic"
self.strategy_key = (("lazy", self.lazy),)
- if hasattr(argument, "__origin__"):
+ argument = de_optionalize_union_types(argument)
- collection_class = argument.__origin__ # type: ignore
- if issubclass(collection_class, abc.Collection):
+ if hasattr(argument, "__origin__"):
+ arg_origin = argument.__origin__ # type: ignore
+ if isinstance(arg_origin, type) and issubclass(
+ arg_origin, abc.Collection
+ ):
if self.collection_class is None:
- self.collection_class = collection_class
+ self.collection_class = arg_origin
elif not is_write_only and not is_dynamic:
self.uselist = False
if argument.__args__: # type: ignore
- if issubclass(
- argument.__origin__, typing.Mapping # type: ignore
+ if isinstance(arg_origin, type) and issubclass(
+ arg_origin, typing.Mapping # type: ignore
):
type_arg = argument.__args__[-1] # type: ignore
else:
id: Mapped[intpk] = mapped_column(init=False)
email_address: Mapped[str]
user_id: Mapped[user_fk] = mapped_column(init=False)
- user: Mapped["User"] = relationship(
+ user: Mapped[Optional["User"]] = relationship(
back_populates="addresses", default=None
)
select(A).join(A.bs), "SELECT a.id FROM a JOIN b ON a.id = b.a_id"
)
- def test_basic_bidirectional(self, decl_base):
+ @testing.combinations(True, False, argnames="optional_on_m2o")
+ def test_basic_bidirectional(self, decl_base, optional_on_m2o):
class A(decl_base):
__tablename__ = "a"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
- a: Mapped["A"] = relationship(
- back_populates="bs", primaryjoin=a_id == A.id
- )
+ if optional_on_m2o:
+ a: Mapped[Optional["A"]] = relationship(
+ back_populates="bs", primaryjoin=a_id == A.id
+ )
+ else:
+ a: Mapped["A"] = relationship(
+ back_populates="bs", primaryjoin=a_id == A.id
+ )
a1 = A(data="data")
b1 = B()