From 54f5def028d8f46ead37e8046d2aea3bb9953ebc Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 14 Jun 2022 09:31:09 -0400 Subject: [PATCH] typing adjustments for composites * if dataclass isn't used, columns have to be named * _CompositeClassProto is not useful as dataclasses have no methods / bases we can use, so composite is against Any * Adjust session.get() feature to work w/ dataclass composites Change-Id: Icc606cc76871c738dc794ea4555fca8a1ab0e0fd --- lib/sqlalchemy/orm/_typing.py | 8 ++- lib/sqlalchemy/orm/descriptor_props.py | 75 ++++++++++++++++------ lib/sqlalchemy/orm/properties.py | 4 ++ lib/sqlalchemy/orm/session.py | 12 +++- lib/sqlalchemy/sql/sqltypes.py | 2 +- test/ext/mypy/plain_files/composite_dc.py | 51 +++++++++++++++ test/orm/declarative/test_typed_mapping.py | 25 ++++++++ test/orm/test_composites.py | 29 +++++++++ 8 files changed, 179 insertions(+), 27 deletions(-) create mode 100644 test/ext/mypy/plain_files/composite_dc.py diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 0e624afe2a..ed04c96c7c 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from .attributes import QueryableAttribute from .base import PassiveFlag from .decl_api import registry as _registry_type - from .descriptor_props import _CompositeClassProto from .interfaces import InspectionAttr from .interfaces import MapperProperty from .interfaces import UserDefinedOption @@ -103,8 +102,11 @@ def is_user_defined_option( return not opt._is_core and opt._is_user_defined # type: ignore -def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]: - return hasattr(obj, "__composite_values__") +def is_composite_class(obj: Any) -> bool: + # inlining is_dataclass(obj) + return hasattr(obj, "__composite_values__") or hasattr( + obj, "__dataclass_fields__" + ) if TYPE_CHECKING: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index a366a9534f..d67319700a 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -28,6 +28,7 @@ from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +import weakref from . import attributes from . import util as orm_util @@ -48,7 +49,6 @@ from .. import sql from .. import util from ..sql import expression from ..sql.elements import BindParameter -from ..util.typing import Protocol if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -78,14 +78,6 @@ _T = TypeVar("_T", bound=Any) _PT = TypeVar("_PT", bound=Any) -class _CompositeClassProto(Protocol): - def __init__(self, *args: Any): - ... - - def __composite_values__(self) -> Tuple[Any, ...]: - ... - - class DescriptorProperty(MapperProperty[_T]): """:class:`.MapperProperty` which proxies access to a user-defined descriptor.""" @@ -167,7 +159,12 @@ _CompositeAttrType = Union[ ] -_CC = TypeVar("_CC", bound=_CompositeClassProto) +_CC = TypeVar("_CC", bound=Any) + + +_composite_getters: weakref.WeakKeyDictionary[ + Type[Any], Callable[[Any], Tuple[Any, ...]] +] = weakref.WeakKeyDictionary() class Composite( @@ -236,6 +233,7 @@ class Composite( util.set_creation_order(self) self._create_descriptor() + self._init_accessor() def instrument_class(self, mapper: Mapper[Any]) -> None: super().instrument_class(mapper) @@ -254,7 +252,7 @@ class Composite( " method; can't get state" ) from ae else: - return accessor() + return accessor() # type: ignore def do_init(self) -> None: """Initialization which occurs after the :class:`.Composite` @@ -337,6 +335,7 @@ class Composite( extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn if ( self.composite_class is None and extracted_mapped_annotation is None @@ -347,14 +346,57 @@ class Composite( if isinstance(argument, str) or hasattr( argument, "__forward_arg__" ): + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) raise sa_exc.ArgumentError( f"Can't use forward ref {argument} for composite " - f"class argument" + f"class argument; set up the type as Mapped[{str_arg}]" ) self.composite_class = argument if is_dataclass(self.composite_class): self._setup_for_dataclass(registry, cls, key) + else: + for attr in self.attrs: + if ( + isinstance(attr, (MappedColumn, schema.Column)) + and attr.name is None + ): + raise sa_exc.ArgumentError( + "Composite class column arguments must be named " + "unless a dataclass is used" + ) + self._init_accessor() + + def _init_accessor(self) -> None: + if is_dataclass(self.composite_class) and not hasattr( + self.composite_class, "__composite_values__" + ): + insp = inspect.signature(self.composite_class) + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + + if ( + self.composite_class is not None + and isinstance(self.composite_class, type) + and self.composite_class not in _composite_getters + ): + if self._generated_composite_accessor is not None: + _composite_getters[ + self.composite_class + ] = self._generated_composite_accessor + elif hasattr(self.composite_class, "__composite_values__"): + _composite_getters[ + self.composite_class + ] = lambda obj: obj.__composite_values__() # type: ignore @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") @@ -388,15 +430,6 @@ class Composite( elif isinstance(attr, schema.Column): decl_base._undefer_column_name(param.name, attr) - if not hasattr(self.composite_class, "__composite_values__"): - getter = operator.attrgetter( - *[p.name for p in insp.parameters.values()] - ) - if len(insp.parameters) == 1: - self._generated_composite_accessor = lambda obj: (getter(obj),) - else: - self._generated_composite_accessor = getter - @util.memoized_property def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: return [getattr(self.parent.class_, prop.key) for prop in self.props] diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d77d6e63c0..0644222936 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -536,6 +536,10 @@ class MappedColumn( util.set_creation_order(new) return new + @property + def name(self) -> str: + return self.column.name + @property def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]: if self.deferred: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 788821b987..ec6f41b286 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -35,6 +35,7 @@ import weakref from . import attributes from . import context +from . import descriptor_props from . import exc from . import identity from . import loading @@ -3193,8 +3194,15 @@ class Session(_SessionClassMethods, EventTarget): ) -> Optional[_O]: # convert composite types to individual args - if is_composite_class(primary_key_identity): - primary_key_identity = primary_key_identity.__composite_values__() + if ( + is_composite_class(primary_key_identity) + and type(primary_key_identity) + in descriptor_props._composite_getters + ): + getter = descriptor_props._composite_getters[ + type(primary_key_identity) + ] + primary_key_identity = getter(primary_key_identity) mapper: Optional[Mapper[_O]] = inspect(entity) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index faa0c794cc..32f0813f5d 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -3359,7 +3359,7 @@ class Uuid(TypeEngine[_UUID_RETURN]): __visit_name__ = "uuid" - collation = None + collation: Optional[str] = None @overload def __init__( diff --git a/test/ext/mypy/plain_files/composite_dc.py b/test/ext/mypy/plain_files/composite_dc.py new file mode 100644 index 0000000000..fa1b16a2a6 --- /dev/null +++ b/test/ext/mypy/plain_files/composite_dc.py @@ -0,0 +1,51 @@ +import dataclasses + +from sqlalchemy import select +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +@dataclasses.dataclass +class Point: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + +class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + x1: Mapped[int] + y1: Mapped[int] + x2: Mapped[int] + y2: Mapped[int] + + # inferred from right hand side + start = composite(Point, "x1", "y1") + + # taken from left hand side + end: Mapped[Point] = composite(Point, "x2", "y2") + + +v1 = Vertex(start=Point(3, 4), end=Point(5, 6)) + +stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)])) + +# EXPECTED_TYPE: Select[Tuple[Vertex]] +reveal_type(stmt) + +# EXPECTED_TYPE: composite.Point +reveal_type(v1.start) + +# EXPECTED_TYPE: composite.Point +reveal_type(v1.end) + +# EXPECTED_TYPE: int +reveal_type(v1.end.y) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 01849a8ee9..afaa099b21 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1088,6 +1088,31 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column(), mapped_column(), mapped_column("zip") ) + def test_cls_not_composite_compliant(self, decl_base): + class Address: + def __init__(self, street: int, state: str, zip_: str): + pass + + street: str + state: str + zip_: str + + with expect_raises_message( + ArgumentError, + r"Composite class column arguments must be " + r"named unless a dataclass is used", + ): + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + def test_fwd_ref_ok_explicit_cls(self, decl_base): @dataclasses.dataclass class Address: diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index b8d9d90080..9f3c52325d 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -1,3 +1,5 @@ +import dataclasses + import sqlalchemy as sa from sqlalchemy import ForeignKey from sqlalchemy import Integer @@ -745,6 +747,33 @@ class PrimaryKeyTest(fixtures.MappedTest): eq_(g.version, g2.version) +class PrimaryKeyTestDataclasses(PrimaryKeyTest): + @classmethod + def setup_mappers(cls): + graphs = cls.tables.graphs + + @dataclasses.dataclass + class Version: + id: int + version: int + + cls.classes.Version = Version + + class Graph(cls.Comparable): + def __init__(self, version): + self.version = version + + cls.mapper_registry.map_imperatively( + Graph, + graphs, + properties={ + "version": sa.orm.composite( + Version, graphs.c.id, graphs.c.version_id + ) + }, + ) + + class DefaultsTest(fixtures.MappedTest): @classmethod def define_tables(cls, metadata): -- 2.47.2