]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
typing adjustments for composites
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Jun 2022 13:31:09 +0000 (09:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Jun 2022 14:58:31 +0000 (10:58 -0400)
* if dataclass isn't used, columns have to be named
* _CompositeClassProto is not useful as dataclasses have no
  methods / bases we can use, so composite is against Any
* Adjust session.get() feature to work w/ dataclass composites

Change-Id: Icc606cc76871c738dc794ea4555fca8a1ab0e0fd

lib/sqlalchemy/orm/_typing.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/sqltypes.py
test/ext/mypy/plain_files/composite_dc.py [new file with mode: 0644]
test/orm/declarative/test_typed_mapping.py
test/orm/test_composites.py

index 0e624afe2a84a26b1afc9aa5c3a6145989f1e329..ed04c96c7c161904731dc4be54c746b6bd280d31 100644 (file)
@@ -24,7 +24,6 @@ if TYPE_CHECKING:
     from .attributes import QueryableAttribute
     from .base import PassiveFlag
     from .decl_api import registry as _registry_type
-    from .descriptor_props import _CompositeClassProto
     from .interfaces import InspectionAttr
     from .interfaces import MapperProperty
     from .interfaces import UserDefinedOption
@@ -103,8 +102,11 @@ def is_user_defined_option(
     return not opt._is_core and opt._is_user_defined  # type: ignore
 
 
-def is_composite_class(obj: Any) -> TypeGuard[_CompositeClassProto]:
-    return hasattr(obj, "__composite_values__")
+def is_composite_class(obj: Any) -> bool:
+    # inlining is_dataclass(obj)
+    return hasattr(obj, "__composite_values__") or hasattr(
+        obj, "__dataclass_fields__"
+    )
 
 
 if TYPE_CHECKING:
index a366a9534f4c87d7265436b528d71b4eaf1b8b1c..d67319700abe5e37e265097fc7cbbfd001488a76 100644 (file)
@@ -28,6 +28,7 @@ from typing import Type
 from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
+import weakref
 
 from . import attributes
 from . import util as orm_util
@@ -48,7 +49,6 @@ from .. import sql
 from .. import util
 from ..sql import expression
 from ..sql.elements import BindParameter
-from ..util.typing import Protocol
 
 if typing.TYPE_CHECKING:
     from ._typing import _InstanceDict
@@ -78,14 +78,6 @@ _T = TypeVar("_T", bound=Any)
 _PT = TypeVar("_PT", bound=Any)
 
 
-class _CompositeClassProto(Protocol):
-    def __init__(self, *args: Any):
-        ...
-
-    def __composite_values__(self) -> Tuple[Any, ...]:
-        ...
-
-
 class DescriptorProperty(MapperProperty[_T]):
     """:class:`.MapperProperty` which proxies access to a
     user-defined descriptor."""
@@ -167,7 +159,12 @@ _CompositeAttrType = Union[
 ]
 
 
-_CC = TypeVar("_CC", bound=_CompositeClassProto)
+_CC = TypeVar("_CC", bound=Any)
+
+
+_composite_getters: weakref.WeakKeyDictionary[
+    Type[Any], Callable[[Any], Tuple[Any, ...]]
+] = weakref.WeakKeyDictionary()
 
 
 class Composite(
@@ -236,6 +233,7 @@ class Composite(
 
         util.set_creation_order(self)
         self._create_descriptor()
+        self._init_accessor()
 
     def instrument_class(self, mapper: Mapper[Any]) -> None:
         super().instrument_class(mapper)
@@ -254,7 +252,7 @@ class Composite(
                     " method; can't get state"
                 ) from ae
             else:
-                return accessor()
+                return accessor()  # type: ignore
 
     def do_init(self) -> None:
         """Initialization which occurs after the :class:`.Composite`
@@ -337,6 +335,7 @@ class Composite(
         extracted_mapped_annotation: Optional[_AnnotationScanType],
         is_dataclass_field: bool,
     ) -> None:
+        MappedColumn = util.preloaded.orm_properties.MappedColumn
         if (
             self.composite_class is None
             and extracted_mapped_annotation is None
@@ -347,14 +346,57 @@ class Composite(
             if isinstance(argument, str) or hasattr(
                 argument, "__forward_arg__"
             ):
+                str_arg = (
+                    argument.__forward_arg__
+                    if hasattr(argument, "__forward_arg__")
+                    else str(argument)
+                )
                 raise sa_exc.ArgumentError(
                     f"Can't use forward ref {argument} for composite "
-                    f"class argument"
+                    f"class argument; set up the type as Mapped[{str_arg}]"
                 )
             self.composite_class = argument
 
         if is_dataclass(self.composite_class):
             self._setup_for_dataclass(registry, cls, key)
+        else:
+            for attr in self.attrs:
+                if (
+                    isinstance(attr, (MappedColumn, schema.Column))
+                    and attr.name is None
+                ):
+                    raise sa_exc.ArgumentError(
+                        "Composite class column arguments must be named "
+                        "unless a dataclass is used"
+                    )
+        self._init_accessor()
+
+    def _init_accessor(self) -> None:
+        if is_dataclass(self.composite_class) and not hasattr(
+            self.composite_class, "__composite_values__"
+        ):
+            insp = inspect.signature(self.composite_class)
+            getter = operator.attrgetter(
+                *[p.name for p in insp.parameters.values()]
+            )
+            if len(insp.parameters) == 1:
+                self._generated_composite_accessor = lambda obj: (getter(obj),)
+            else:
+                self._generated_composite_accessor = getter
+
+        if (
+            self.composite_class is not None
+            and isinstance(self.composite_class, type)
+            and self.composite_class not in _composite_getters
+        ):
+            if self._generated_composite_accessor is not None:
+                _composite_getters[
+                    self.composite_class
+                ] = self._generated_composite_accessor
+            elif hasattr(self.composite_class, "__composite_values__"):
+                _composite_getters[
+                    self.composite_class
+                ] = lambda obj: obj.__composite_values__()  # type: ignore
 
     @util.preload_module("sqlalchemy.orm.properties")
     @util.preload_module("sqlalchemy.orm.decl_base")
@@ -388,15 +430,6 @@ class Composite(
             elif isinstance(attr, schema.Column):
                 decl_base._undefer_column_name(param.name, attr)
 
-        if not hasattr(self.composite_class, "__composite_values__"):
-            getter = operator.attrgetter(
-                *[p.name for p in insp.parameters.values()]
-            )
-            if len(insp.parameters) == 1:
-                self._generated_composite_accessor = lambda obj: (getter(obj),)
-            else:
-                self._generated_composite_accessor = getter
-
     @util.memoized_property
     def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]:
         return [getattr(self.parent.class_, prop.key) for prop in self.props]
index d77d6e63c0e776b5d7859e16ab7bca75cdfc69c8..064422293628c447adceb4cea691aab913e4f35d 100644 (file)
@@ -536,6 +536,10 @@ class MappedColumn(
         util.set_creation_order(new)
         return new
 
+    @property
+    def name(self) -> str:
+        return self.column.name
+
     @property
     def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
         if self.deferred:
index 788821b98737932a366098e8c8b0194cd0eafae3..ec6f41b286f287d05458760d74cbccc8842ee090 100644 (file)
@@ -35,6 +35,7 @@ import weakref
 
 from . import attributes
 from . import context
+from . import descriptor_props
 from . import exc
 from . import identity
 from . import loading
@@ -3193,8 +3194,15 @@ class Session(_SessionClassMethods, EventTarget):
     ) -> Optional[_O]:
 
         # convert composite types to individual args
-        if is_composite_class(primary_key_identity):
-            primary_key_identity = primary_key_identity.__composite_values__()
+        if (
+            is_composite_class(primary_key_identity)
+            and type(primary_key_identity)
+            in descriptor_props._composite_getters
+        ):
+            getter = descriptor_props._composite_getters[
+                type(primary_key_identity)
+            ]
+            primary_key_identity = getter(primary_key_identity)
 
         mapper: Optional[Mapper[_O]] = inspect(entity)
 
index faa0c794cc97b65f726b80c57f52de0533d288cd..32f0813f5d9e95e50e5b28cac7e4235d97687f22 100644 (file)
@@ -3359,7 +3359,7 @@ class Uuid(TypeEngine[_UUID_RETURN]):
 
     __visit_name__ = "uuid"
 
-    collation = None
+    collation: Optional[str] = None
 
     @overload
     def __init__(
diff --git a/test/ext/mypy/plain_files/composite_dc.py b/test/ext/mypy/plain_files/composite_dc.py
new file mode 100644 (file)
index 0000000..fa1b16a
--- /dev/null
@@ -0,0 +1,51 @@
+import dataclasses
+
+from sqlalchemy import select
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+@dataclasses.dataclass
+class Point:
+    def __init__(self, x: int, y: int):
+        self.x = x
+        self.y = y
+
+
+class Vertex(Base):
+    __tablename__ = "vertices"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    x1: Mapped[int]
+    y1: Mapped[int]
+    x2: Mapped[int]
+    y2: Mapped[int]
+
+    # inferred from right hand side
+    start = composite(Point, "x1", "y1")
+
+    # taken from left hand side
+    end: Mapped[Point] = composite(Point, "x2", "y2")
+
+
+v1 = Vertex(start=Point(3, 4), end=Point(5, 6))
+
+stmt = select(Vertex).where(Vertex.start.in_([Point(3, 4)]))
+
+# EXPECTED_TYPE: Select[Tuple[Vertex]]
+reveal_type(stmt)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.start)
+
+# EXPECTED_TYPE: composite.Point
+reveal_type(v1.end)
+
+# EXPECTED_TYPE: int
+reveal_type(v1.end.y)
index 01849a8ee9f216b8efbc9227dc2b613b7c07b67b..afaa099b21326d91b1a6428ef8d613033a20bd92 100644 (file)
@@ -1088,6 +1088,31 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column(), mapped_column(), mapped_column("zip")
                 )
 
+    def test_cls_not_composite_compliant(self, decl_base):
+        class Address:
+            def __init__(self, street: int, state: str, zip_: str):
+                pass
+
+            street: str
+            state: str
+            zip_: str
+
+        with expect_raises_message(
+            ArgumentError,
+            r"Composite class column arguments must be "
+            r"named unless a dataclass is used",
+        ):
+
+            class User(decl_base):
+                __tablename__ = "user"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                name: Mapped[str] = mapped_column()
+
+                address: Mapped[Address] = composite(
+                    mapped_column(), mapped_column(), mapped_column("zip")
+                )
+
     def test_fwd_ref_ok_explicit_cls(self, decl_base):
         @dataclasses.dataclass
         class Address:
index b8d9d900804f91a10db8b1385b8017869dc2c51e..9f3c52325d5ed275debd7a4d762bd973e4fb8cbf 100644 (file)
@@ -1,3 +1,5 @@
+import dataclasses
+
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
@@ -745,6 +747,33 @@ class PrimaryKeyTest(fixtures.MappedTest):
         eq_(g.version, g2.version)
 
 
+class PrimaryKeyTestDataclasses(PrimaryKeyTest):
+    @classmethod
+    def setup_mappers(cls):
+        graphs = cls.tables.graphs
+
+        @dataclasses.dataclass
+        class Version:
+            id: int
+            version: int
+
+        cls.classes.Version = Version
+
+        class Graph(cls.Comparable):
+            def __init__(self, version):
+                self.version = version
+
+        cls.mapper_registry.map_imperatively(
+            Graph,
+            graphs,
+            properties={
+                "version": sa.orm.composite(
+                    Version, graphs.c.id, graphs.c.version_id
+                )
+            },
+        )
+
+
 class DefaultsTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):