From: Federico Caselli Date: Sat, 21 May 2022 09:32:37 +0000 (+0200) Subject: Improvements on dataclass_transform feature X-Git-Tag: rel_2_0_0b1~294^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1c30e66e5d2d085a6e1975532d7edfc0b94b3a48;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improvements on dataclass_transform feature Change-Id: Iaf80526b70368cd4ed4147fdce9f6525b113474a --- diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 553a50107f..feeda98f83 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -45,6 +45,7 @@ from .base import _inspect_mapped_class from .base import Mapped from .decl_base import _add_attribute from .decl_base import _as_declarative +from .decl_base import _ClassScanMapperConfig from .decl_base import _declarative_constructor from .decl_base import _DeferredMapperConfig from .decl_base import _del_attribute @@ -60,6 +61,7 @@ from .state import InstanceState from .. import exc from .. import inspection from .. import util +from ..sql.base import _NoArg from ..sql.elements import SQLCoreOperations from ..sql.schema import MetaData from ..sql.selectable import FromClause @@ -72,11 +74,11 @@ from ..util.typing import Literal if TYPE_CHECKING: from ._typing import _O from ._typing import _RegistryType + from .decl_base import _DataclassArguments from .instrumentation import ClassManager from .interfaces import MapperProperty from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument - _T = TypeVar("_T", bound=Any) # it's not clear how to have Annotated, Union objects etc. as keys here @@ -588,19 +590,33 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): def __init_subclass__( cls, - init: bool = True, - repr: bool = True, # noqa: A002 - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, ) -> None: - cls._sa_apply_dc_transforms = { + + apply_dc_transforms: _DataclassArguments = { "init": init, "repr": repr, "eq": eq, "order": order, "unsafe_hash": unsafe_hash, } + + if hasattr(cls, "_sa_apply_dc_transforms"): + current = cls._sa_apply_dc_transforms # type: ignore[attr-defined] + + _ClassScanMapperConfig._assert_dc_arguments(current) + + cls._sa_apply_dc_transforms = { + k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v + for k, v in apply_dc_transforms.items() + } + else: + cls._sa_apply_dc_transforms = apply_dc_transforms + super().__init_subclass__() @@ -1229,11 +1245,11 @@ class registry: self, __cls: Literal[None] = ..., *, - init: bool = True, - repr: bool = True, # noqa: A002 - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, + init: Union[_NoArg, bool] = ..., + repr: Union[_NoArg, bool] = ..., # noqa: A002 + eq: Union[_NoArg, bool] = ..., + order: Union[_NoArg, bool] = ..., + unsafe_hash: Union[_NoArg, bool] = ..., ) -> Callable[[Type[_O]], Type[_O]]: ... @@ -1241,11 +1257,11 @@ class registry: self, __cls: Optional[Type[_O]] = None, *, - init: bool = True, - repr: bool = True, # noqa: A002 - eq: bool = True, - order: bool = False, - unsafe_hash: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: """Class decorator that will apply the Declarative mapping process to a given class, and additionally convert the class to be a diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 54a272f86e..1e7c0eaf6a 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -64,6 +64,7 @@ from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType from ..util.typing import Protocol +from ..util.typing import TypedDict if TYPE_CHECKING: from ._typing import _ClassDict @@ -89,6 +90,8 @@ class _DeclMappedClassProtocol(Protocol[_O]): __mapper_args__: Mapping[str, Any] __table_args__: Optional[_TableArgsType] + _sa_apply_dc_transforms: Optional[_DataclassArguments] + def __declare_first__(self) -> None: pass @@ -96,6 +99,14 @@ class _DeclMappedClassProtocol(Protocol[_O]): pass +class _DataclassArguments(TypedDict): + init: Union[_NoArg, bool] + repr: Union[_NoArg, bool] + eq: Union[_NoArg, bool] + order: Union[_NoArg, bool] + unsafe_hash: Union[_NoArg, bool] + + def _declared_mapping_info( cls: Type[Any], ) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]: @@ -419,9 +430,10 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] - dataclass_setup_arguments: Optional[Dict[str, Any]] + dataclass_setup_arguments: Optional[_DataclassArguments] """if the class has SQLAlchemy native dataclass parameters, where - we will create a SQLAlchemy dataclass (not a real dataclass). + we will turn the class into a dataclass within the declarative mapping + process. """ @@ -956,7 +968,36 @@ class _ClassScanMapperConfig(_MapperConfig): setattr(self.cls, k, v) self.cls.__annotations__ = annotations - dataclasses.dataclass(self.cls, **dataclass_setup_arguments) + self._assert_dc_arguments(dataclass_setup_arguments) + + dataclasses.dataclass( + self.cls, + **{ + k: v + for k, v in dataclass_setup_arguments.items() + if v is not _NoArg.NO_ARG + }, + ) + + @classmethod + def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: + disallowed_args = set(arguments).difference( + { + "init", + "repr", + "order", + "eq", + "unsafe_hash", + } + ) + if disallowed_args: + raise exc.ArgumentError( + f"Dataclass argument(s) " + f"""{ + ', '.join(f'{arg!r}' + for arg in sorted(disallowed_args)) + } are not accepted""" + ) def _collect_annotation( self, diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index aac8737232..308ebfeb17 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -1,5 +1,6 @@ import dataclasses import inspect as pyinspect +from itertools import product from typing import Any from typing import List from typing import Optional @@ -488,14 +489,26 @@ class RelationshipDefaultFactoryTest(fixtures.TestBase): class DataclassArgsTest(fixtures.TestBase): dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") - @testing.fixture(params=dc_arg_names) + @testing.fixture(params=product(dc_arg_names, (True, False))) def dc_argument_fixture(self, request: Any, registry: _RegistryType): - name = request.param + name, use_defaults = request.param args = {n: n == name for n in self.dc_arg_names} if args["order"]: args["eq"] = True - yield args + if use_defaults: + default = { + "init": True, + "repr": True, + "eq": True, + "order": False, + "unsafe_hash": False, + } + to_apply = {k: v for k, v in args.items() if v} + effective = {**default, **to_apply} + return to_apply, effective + else: + return args, args @testing.fixture( params=["mapped_column", "synonym", "deferred", "column_property"] @@ -674,7 +687,7 @@ class DataclassArgsTest(fixtures.TestBase): mapped_expr_constructor, registry: _RegistryType, ): - @registry.mapped_as_dataclass(**dc_argument_fixture) + @registry.mapped_as_dataclass(**dc_argument_fixture[0]) class A: __tablename__ = "a" @@ -685,7 +698,7 @@ class DataclassArgsTest(fixtures.TestBase): x: Mapped[Optional[int]] = mapped_expr_constructor - self._assert_cls(A, dc_argument_fixture) + self._assert_cls(A, dc_argument_fixture[1]) def test_dc_arguments_base( self, @@ -695,7 +708,9 @@ class DataclassArgsTest(fixtures.TestBase): ): reg = registry - class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture): + class Base( + MappedAsDataclass, DeclarativeBase, **dc_argument_fixture[0] + ): registry = reg class A(Base): @@ -708,7 +723,7 @@ class DataclassArgsTest(fixtures.TestBase): x: Mapped[Optional[int]] = mapped_expr_constructor - self.A = A + self._assert_cls(A, dc_argument_fixture[1]) def test_dc_arguments_perclass( self, @@ -716,7 +731,7 @@ class DataclassArgsTest(fixtures.TestBase): mapped_expr_constructor, decl_base: Type[DeclarativeBase], ): - class A(MappedAsDataclass, decl_base, **dc_argument_fixture): + class A(MappedAsDataclass, decl_base, **dc_argument_fixture[0]): __tablename__ = "a" id: Mapped[int] = mapped_column(primary_key=True, init=False) @@ -726,7 +741,106 @@ class DataclassArgsTest(fixtures.TestBase): x: Mapped[Optional[int]] = mapped_expr_constructor - self.A = A + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_override_base(self, registry: _RegistryType): + reg = registry + + class Base(MappedAsDataclass, DeclarativeBase, init=False, order=True): + registry = reg + + class A(Base, init=True, repr=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_column(default=7) + + effective = { + "init": True, + "repr": False, + "eq": True, + "order": True, + "unsafe_hash": False, + } + self._assert_cls(A, effective) + + def test_dc_base_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + class Base(MappedAsDataclass, DeclarativeBase, slots=True): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + class A(Base2, slots=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def test_dc_decorator_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class Base(DeclarativeBase): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class A(Base2): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def test_dc_raise_for_slots( + self, + registry: _RegistryType, + decl_base: Type[DeclarativeBase], + ): + reg = registry + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + class A(MappedAsDataclass, decl_base): + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots' are not accepted", + ): + + class Base(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + _sa_apply_dc_transforms = {"slots": True} + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + @reg.mapped + class C: + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):