from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
+import weakref
from . import attributes
from . import util as orm_util
from .. import util
from ..sql import expression
from ..sql.elements import BindParameter
-from ..util.typing import Protocol
if typing.TYPE_CHECKING:
from ._typing import _InstanceDict
_PT = TypeVar("_PT", bound=Any)
-class _CompositeClassProto(Protocol):
- def __init__(self, *args: Any):
- ...
-
- def __composite_values__(self) -> Tuple[Any, ...]:
- ...
-
-
class DescriptorProperty(MapperProperty[_T]):
""":class:`.MapperProperty` which proxies access to a
user-defined descriptor."""
]
-_CC = TypeVar("_CC", bound=_CompositeClassProto)
+_CC = TypeVar("_CC", bound=Any)
+
+
+_composite_getters: weakref.WeakKeyDictionary[
+ Type[Any], Callable[[Any], Tuple[Any, ...]]
+] = weakref.WeakKeyDictionary()
class Composite(
util.set_creation_order(self)
self._create_descriptor()
+ self._init_accessor()
def instrument_class(self, mapper: Mapper[Any]) -> None:
super().instrument_class(mapper)
" method; can't get state"
) from ae
else:
- return accessor()
+ return accessor() # type: ignore
def do_init(self) -> None:
"""Initialization which occurs after the :class:`.Composite`
extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
+ MappedColumn = util.preloaded.orm_properties.MappedColumn
if (
self.composite_class is None
and extracted_mapped_annotation is None
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
):
+ str_arg = (
+ argument.__forward_arg__
+ if hasattr(argument, "__forward_arg__")
+ else str(argument)
+ )
raise sa_exc.ArgumentError(
f"Can't use forward ref {argument} for composite "
- f"class argument"
+ f"class argument; set up the type as Mapped[{str_arg}]"
)
self.composite_class = argument
if is_dataclass(self.composite_class):
self._setup_for_dataclass(registry, cls, key)
+ else:
+ for attr in self.attrs:
+ if (
+ isinstance(attr, (MappedColumn, schema.Column))
+ and attr.name is None
+ ):
+ raise sa_exc.ArgumentError(
+ "Composite class column arguments must be named "
+ "unless a dataclass is used"
+ )
+ self._init_accessor()
+
+ def _init_accessor(self) -> None:
+ if is_dataclass(self.composite_class) and not hasattr(
+ self.composite_class, "__composite_values__"
+ ):
+ insp = inspect.signature(self.composite_class)
+ getter = operator.attrgetter(
+ *[p.name for p in insp.parameters.values()]
+ )
+ if len(insp.parameters) == 1:
+ self._generated_composite_accessor = lambda obj: (getter(obj),)
+ else:
+ self._generated_composite_accessor = getter
+
+ if (
+ self.composite_class is not None
+ and isinstance(self.composite_class, type)
+ and self.composite_class not in _composite_getters
+ ):
+ if self._generated_composite_accessor is not None:
+ _composite_getters[
+ self.composite_class
+ ] = self._generated_composite_accessor
+ elif hasattr(self.composite_class, "__composite_values__"):
+ _composite_getters[
+ self.composite_class
+ ] = lambda obj: obj.__composite_values__() # type: ignore
@util.preload_module("sqlalchemy.orm.properties")
@util.preload_module("sqlalchemy.orm.decl_base")
elif isinstance(attr, schema.Column):
decl_base._undefer_column_name(param.name, attr)
- if not hasattr(self.composite_class, "__composite_values__"):
- getter = operator.attrgetter(
- *[p.name for p in insp.parameters.values()]
- )
- if len(insp.parameters) == 1:
- self._generated_composite_accessor = lambda obj: (getter(obj),)
- else:
- self._generated_composite_accessor = getter
-
@util.memoized_property
def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
return [getattr(self.parent.class_, prop.key) for prop in self.props]
--- /dev/null
+import dataclasses
+
+from sqlalchemy import select
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+ pass
+
+
+@dataclasses.dataclass
+class Point:
+ def __init__(self, x: int, y: int):
+ self.x = x
+ self.y = y
+
+
+class Vertex(Base):
+ __tablename__ = "vertices"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ x1: Mapped[int]
+ y1: Mapped[int]
+ x2: Mapped[int]
+ y2: Mapped[int]
+
+ # inferred from right hand side
+ start = composite(Point, "x1", "y1")
+
+ # taken from left hand side
+ end: Mapped[Point] = composite(Point, "x2", "y2")
+
+
+v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
+
+stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
+
+# EXPECTED_TYPE: Select[Tuple[Vertex]]
+reveal_type(stmt)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.start)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.end)
+
+# EXPECTED_TYPE: int
+reveal_type(v1.end.y)