]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Improvements on dataclass_transform feature
authorFederico Caselli <cfederico87@gmail.com>
Sat, 21 May 2022 09:32:37 +0000 (11:32 +0200)
committermike bayer <mike_mp@zzzcomputing.com>
Sun, 22 May 2022 16:58:48 +0000 (16:58 +0000)
Change-Id: Iaf80526b70368cd4ed4147fdce9f6525b113474a

lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_dc_transforms.py

index 553a50107f950ca8b7a09b5f5176cfd270b15668..feeda98f83851e951b7784ca58686b4405a89068 100644 (file)
@@ -45,6 +45,7 @@ from .base import _inspect_mapped_class
 from .base import Mapped
 from .decl_base import _add_attribute
 from .decl_base import _as_declarative
+from .decl_base import _ClassScanMapperConfig
 from .decl_base import _declarative_constructor
 from .decl_base import _DeferredMapperConfig
 from .decl_base import _del_attribute
@@ -60,6 +61,7 @@ from .state import InstanceState
 from .. import exc
 from .. import inspection
 from .. import util
+from ..sql.base import _NoArg
 from ..sql.elements import SQLCoreOperations
 from ..sql.schema import MetaData
 from ..sql.selectable import FromClause
@@ -72,11 +74,11 @@ from ..util.typing import Literal
 if TYPE_CHECKING:
     from ._typing import _O
     from ._typing import _RegistryType
+    from .decl_base import _DataclassArguments
     from .instrumentation import ClassManager
     from .interfaces import MapperProperty
     from .state import InstanceState  # noqa
     from ..sql._typing import _TypeEngineArgument
-
 _T = TypeVar("_T", bound=Any)
 
 # it's not clear how to have Annotated, Union objects etc. as keys here
@@ -588,19 +590,33 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
 
     def __init_subclass__(
         cls,
-        init: bool = True,
-        repr: bool = True,  # noqa: A002
-        eq: bool = True,
-        order: bool = False,
-        unsafe_hash: bool = False,
+        init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+        eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        order: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
     ) -> None:
-        cls._sa_apply_dc_transforms = {
+
+        apply_dc_transforms: _DataclassArguments = {
             "init": init,
             "repr": repr,
             "eq": eq,
             "order": order,
             "unsafe_hash": unsafe_hash,
         }
+
+        if hasattr(cls, "_sa_apply_dc_transforms"):
+            current = cls._sa_apply_dc_transforms  # type: ignore[attr-defined]
+
+            _ClassScanMapperConfig._assert_dc_arguments(current)
+
+            cls._sa_apply_dc_transforms = {
+                k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v
+                for k, v in apply_dc_transforms.items()
+            }
+        else:
+            cls._sa_apply_dc_transforms = apply_dc_transforms
+
         super().__init_subclass__()
 
 
@@ -1229,11 +1245,11 @@ class registry:
         self,
         __cls: Literal[None] = ...,
         *,
-        init: bool = True,
-        repr: bool = True,  # noqa: A002
-        eq: bool = True,
-        order: bool = False,
-        unsafe_hash: bool = False,
+        init: Union[_NoArg, bool] = ...,
+        repr: Union[_NoArg, bool] = ...,  # noqa: A002
+        eq: Union[_NoArg, bool] = ...,
+        order: Union[_NoArg, bool] = ...,
+        unsafe_hash: Union[_NoArg, bool] = ...,
     ) -> Callable[[Type[_O]], Type[_O]]:
         ...
 
@@ -1241,11 +1257,11 @@ class registry:
         self,
         __cls: Optional[Type[_O]] = None,
         *,
-        init: bool = True,
-        repr: bool = True,  # noqa: A002
-        eq: bool = True,
-        order: bool = False,
-        unsafe_hash: bool = False,
+        init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        repr: Union[_NoArg, bool] = _NoArg.NO_ARG,  # noqa: A002
+        eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        order: Union[_NoArg, bool] = _NoArg.NO_ARG,
+        unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
     ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
         """Class decorator that will apply the Declarative mapping process
         to a given class, and additionally convert the class to be a
index 54a272f86e9920c79668fc17b784d74fe3c5247c..1e7c0eaf6ab51446ed91af37568ead541fd6c157 100644 (file)
@@ -64,6 +64,7 @@ from ..sql.schema import Table
 from ..util import topological
 from ..util.typing import _AnnotationScanType
 from ..util.typing import Protocol
+from ..util.typing import TypedDict
 
 if TYPE_CHECKING:
     from ._typing import _ClassDict
@@ -89,6 +90,8 @@ class _DeclMappedClassProtocol(Protocol[_O]):
     __mapper_args__: Mapping[str, Any]
     __table_args__: Optional[_TableArgsType]
 
+    _sa_apply_dc_transforms: Optional[_DataclassArguments]
+
     def __declare_first__(self) -> None:
         pass
 
@@ -96,6 +99,14 @@ class _DeclMappedClassProtocol(Protocol[_O]):
         pass
 
 
+class _DataclassArguments(TypedDict):
+    init: Union[_NoArg, bool]
+    repr: Union[_NoArg, bool]
+    eq: Union[_NoArg, bool]
+    order: Union[_NoArg, bool]
+    unsafe_hash: Union[_NoArg, bool]
+
+
 def _declared_mapping_info(
     cls: Type[Any],
 ) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
@@ -419,9 +430,10 @@ class _ClassScanMapperConfig(_MapperConfig):
     mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
     inherits: Optional[Type[Any]]
 
-    dataclass_setup_arguments: Optional[Dict[str, Any]]
+    dataclass_setup_arguments: Optional[_DataclassArguments]
     """if the class has SQLAlchemy native dataclass parameters, where
-    we will create a SQLAlchemy dataclass (not a real dataclass).
+    we will turn the class into a dataclass within the declarative mapping
+    process.
 
     """
 
@@ -956,7 +968,36 @@ class _ClassScanMapperConfig(_MapperConfig):
             setattr(self.cls, k, v)
         self.cls.__annotations__ = annotations
 
-        dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+        self._assert_dc_arguments(dataclass_setup_arguments)
+
+        dataclasses.dataclass(
+            self.cls,
+            **{
+                k: v
+                for k, v in dataclass_setup_arguments.items()
+                if v is not _NoArg.NO_ARG
+            },
+        )
+
+    @classmethod
+    def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
+        disallowed_args = set(arguments).difference(
+            {
+                "init",
+                "repr",
+                "order",
+                "eq",
+                "unsafe_hash",
+            }
+        )
+        if disallowed_args:
+            raise exc.ArgumentError(
+                f"Dataclass argument(s) "
+                f"""{
+                    ', '.join(f'{arg!r}'
+                    for arg in sorted(disallowed_args))
+                } are not accepted"""
+            )
 
     def _collect_annotation(
         self,
index aac87372320f32cb04d1213cd3fc529be5d77332..308ebfeb17aa6b8b9b1707f1897c038ca1aef371 100644 (file)
@@ -1,5 +1,6 @@
 import dataclasses
 import inspect as pyinspect
+from itertools import product
 from typing import Any
 from typing import List
 from typing import Optional
@@ -488,14 +489,26 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase):
 class DataclassArgsTest(fixtures.TestBase):
     dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
 
-    @testing.fixture(params=dc_arg_names)
+    @testing.fixture(params=product(dc_arg_names, (True, False)))
     def dc_argument_fixture(self, request: Any, registry: _RegistryType):
-        name = request.param
+        name, use_defaults = request.param
 
         args = {n: n == name for n in self.dc_arg_names}
         if args["order"]:
             args["eq"] = True
-        yield args
+        if use_defaults:
+            default = {
+                "init": True,
+                "repr": True,
+                "eq": True,
+                "order": False,
+                "unsafe_hash": False,
+            }
+            to_apply = {k: v for k, v in args.items() if v}
+            effective = {**default, **to_apply}
+            return to_apply, effective
+        else:
+            return args, args
 
     @testing.fixture(
         params=["mapped_column", "synonym", "deferred", "column_property"]
@@ -674,7 +687,7 @@ class DataclassArgsTest(fixtures.TestBase):
         mapped_expr_constructor,
         registry: _RegistryType,
     ):
-        @registry.mapped_as_dataclass(**dc_argument_fixture)
+        @registry.mapped_as_dataclass(**dc_argument_fixture[0])
         class A:
             __tablename__ = "a"
 
@@ -685,7 +698,7 @@ class DataclassArgsTest(fixtures.TestBase):
 
             x: Mapped[Optional[int]] = mapped_expr_constructor
 
-        self._assert_cls(A, dc_argument_fixture)
+        self._assert_cls(A, dc_argument_fixture[1])
 
     def test_dc_arguments_base(
         self,
@@ -695,7 +708,9 @@ class DataclassArgsTest(fixtures.TestBase):
     ):
         reg = registry
 
-        class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture):
+        class Base(
+            MappedAsDataclass, DeclarativeBase, **dc_argument_fixture[0]
+        ):
             registry = reg
 
         class A(Base):
@@ -708,7 +723,7 @@ class DataclassArgsTest(fixtures.TestBase):
 
             x: Mapped[Optional[int]] = mapped_expr_constructor
 
-        self.A = A
+        self._assert_cls(A, dc_argument_fixture[1])
 
     def test_dc_arguments_perclass(
         self,
@@ -716,7 +731,7 @@ class DataclassArgsTest(fixtures.TestBase):
         mapped_expr_constructor,
         decl_base: Type[DeclarativeBase],
     ):
-        class A(MappedAsDataclass, decl_base, **dc_argument_fixture):
+        class A(MappedAsDataclass, decl_base, **dc_argument_fixture[0]):
             __tablename__ = "a"
 
             id: Mapped[int] = mapped_column(primary_key=True, init=False)
@@ -726,7 +741,106 @@ class DataclassArgsTest(fixtures.TestBase):
 
             x: Mapped[Optional[int]] = mapped_expr_constructor
 
-        self.A = A
+        self._assert_cls(A, dc_argument_fixture[1])
+
+    def test_dc_arguments_override_base(self, registry: _RegistryType):
+        reg = registry
+
+        class Base(MappedAsDataclass, DeclarativeBase, init=False, order=True):
+            registry = reg
+
+        class A(Base, init=True, repr=False):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: Mapped[str]
+
+            some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+            x: Mapped[Optional[int]] = mapped_column(default=7)
+
+        effective = {
+            "init": True,
+            "repr": False,
+            "eq": True,
+            "order": True,
+            "unsafe_hash": False,
+        }
+        self._assert_cls(A, effective)
+
+    def test_dc_base_unsupported_argument(self, registry: _RegistryType):
+        reg = registry
+        with expect_raises(TypeError):
+
+            class Base(MappedAsDataclass, DeclarativeBase, slots=True):
+                registry = reg
+
+        class Base2(MappedAsDataclass, DeclarativeBase, order=True):
+            registry = reg
+
+        with expect_raises(TypeError):
+
+            class A(Base2, slots=False):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+    def test_dc_decorator_unsupported_argument(self, registry: _RegistryType):
+        reg = registry
+        with expect_raises(TypeError):
+
+            @registry.mapped_as_dataclass(slots=True)
+            class Base(DeclarativeBase):
+                registry = reg
+
+        class Base2(MappedAsDataclass, DeclarativeBase, order=True):
+            registry = reg
+
+        with expect_raises(TypeError):
+
+            @registry.mapped_as_dataclass(slots=True)
+            class A(Base2):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+    def test_dc_raise_for_slots(
+        self,
+        registry: _RegistryType,
+        decl_base: Type[DeclarativeBase],
+    ):
+        reg = registry
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted",
+        ):
+
+            class A(MappedAsDataclass, decl_base):
+                __tablename__ = "a"
+                _sa_apply_dc_transforms = {"slots": True, "unknown": 5}
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Dataclass argument\(s\) 'slots' are not accepted",
+        ):
+
+            class Base(MappedAsDataclass, DeclarativeBase, order=True):
+                registry = reg
+                _sa_apply_dc_transforms = {"slots": True}
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted",
+        ):
+
+            @reg.mapped
+            class C:
+                __tablename__ = "a"
+                _sa_apply_dc_transforms = {"slots": True, "unknown": 5}
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
 
 
 class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):