from .base import Mapped
from .decl_base import _add_attribute
from .decl_base import _as_declarative
+from .decl_base import _ClassScanMapperConfig
from .decl_base import _declarative_constructor
from .decl_base import _DeferredMapperConfig
from .decl_base import _del_attribute
from .. import exc
from .. import inspection
from .. import util
+from ..sql.base import _NoArg
from ..sql.elements import SQLCoreOperations
from ..sql.schema import MetaData
from ..sql.selectable import FromClause
if TYPE_CHECKING:
from ._typing import _O
from ._typing import _RegistryType
+ from .decl_base import _DataclassArguments
from .instrumentation import ClassManager
from .interfaces import MapperProperty
from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
-
_T = TypeVar("_T", bound=Any)
# it's not clear how to have Annotated, Union objects etc. as keys here
def __init_subclass__(
cls,
- init: bool = True,
- repr: bool = True, # noqa: A002
- eq: bool = True,
- order: bool = False,
- unsafe_hash: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ order: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
) -> None:
- cls._sa_apply_dc_transforms = {
+
+ apply_dc_transforms: _DataclassArguments = {
"init": init,
"repr": repr,
"eq": eq,
"order": order,
"unsafe_hash": unsafe_hash,
}
+
+ if hasattr(cls, "_sa_apply_dc_transforms"):
+ current = cls._sa_apply_dc_transforms # type: ignore[attr-defined]
+
+ _ClassScanMapperConfig._assert_dc_arguments(current)
+
+ cls._sa_apply_dc_transforms = {
+ k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v
+ for k, v in apply_dc_transforms.items()
+ }
+ else:
+ cls._sa_apply_dc_transforms = apply_dc_transforms
+
super().__init_subclass__()
self,
__cls: Literal[None] = ...,
*,
- init: bool = True,
- repr: bool = True, # noqa: A002
- eq: bool = True,
- order: bool = False,
- unsafe_hash: bool = False,
+ init: Union[_NoArg, bool] = ...,
+ repr: Union[_NoArg, bool] = ..., # noqa: A002
+ eq: Union[_NoArg, bool] = ...,
+ order: Union[_NoArg, bool] = ...,
+ unsafe_hash: Union[_NoArg, bool] = ...,
) -> Callable[[Type[_O]], Type[_O]]:
...
self,
__cls: Optional[Type[_O]] = None,
*,
- init: bool = True,
- repr: bool = True, # noqa: A002
- eq: bool = True,
- order: bool = False,
- unsafe_hash: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ eq: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ order: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
"""Class decorator that will apply the Declarative mapping process
to a given class, and additionally convert the class to be a
from ..util import topological
from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
+from ..util.typing import TypedDict
if TYPE_CHECKING:
from ._typing import _ClassDict
__mapper_args__: Mapping[str, Any]
__table_args__: Optional[_TableArgsType]
+ _sa_apply_dc_transforms: Optional[_DataclassArguments]
+
def __declare_first__(self) -> None:
pass
pass
+class _DataclassArguments(TypedDict):
+ init: Union[_NoArg, bool]
+ repr: Union[_NoArg, bool]
+ eq: Union[_NoArg, bool]
+ order: Union[_NoArg, bool]
+ unsafe_hash: Union[_NoArg, bool]
+
+
def _declared_mapping_info(
cls: Type[Any],
) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]:
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
- dataclass_setup_arguments: Optional[Dict[str, Any]]
+ dataclass_setup_arguments: Optional[_DataclassArguments]
"""if the class has SQLAlchemy native dataclass parameters, where
- we will create a SQLAlchemy dataclass (not a real dataclass).
+ we will turn the class into a dataclass within the declarative mapping
+ process.
"""
setattr(self.cls, k, v)
self.cls.__annotations__ = annotations
- dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+ self._assert_dc_arguments(dataclass_setup_arguments)
+
+ dataclasses.dataclass(
+ self.cls,
+ **{
+ k: v
+ for k, v in dataclass_setup_arguments.items()
+ if v is not _NoArg.NO_ARG
+ },
+ )
+
+ @classmethod
+ def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
+ disallowed_args = set(arguments).difference(
+ {
+ "init",
+ "repr",
+ "order",
+ "eq",
+ "unsafe_hash",
+ }
+ )
+ if disallowed_args:
+ raise exc.ArgumentError(
+ f"Dataclass argument(s) "
+ f"""{
+ ', '.join(f'{arg!r}'
+ for arg in sorted(disallowed_args))
+ } are not accepted"""
+ )
def _collect_annotation(
self,
import dataclasses
import inspect as pyinspect
+from itertools import product
from typing import Any
from typing import List
from typing import Optional
class DataclassArgsTest(fixtures.TestBase):
dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
- @testing.fixture(params=dc_arg_names)
+ @testing.fixture(params=product(dc_arg_names, (True, False)))
def dc_argument_fixture(self, request: Any, registry: _RegistryType):
- name = request.param
+ name, use_defaults = request.param
args = {n: n == name for n in self.dc_arg_names}
if args["order"]:
args["eq"] = True
- yield args
+ if use_defaults:
+ default = {
+ "init": True,
+ "repr": True,
+ "eq": True,
+ "order": False,
+ "unsafe_hash": False,
+ }
+ to_apply = {k: v for k, v in args.items() if v}
+ effective = {**default, **to_apply}
+ return to_apply, effective
+ else:
+ return args, args
@testing.fixture(
params=["mapped_column", "synonym", "deferred", "column_property"]
mapped_expr_constructor,
registry: _RegistryType,
):
- @registry.mapped_as_dataclass(**dc_argument_fixture)
+ @registry.mapped_as_dataclass(**dc_argument_fixture[0])
class A:
__tablename__ = "a"
x: Mapped[Optional[int]] = mapped_expr_constructor
- self._assert_cls(A, dc_argument_fixture)
+ self._assert_cls(A, dc_argument_fixture[1])
def test_dc_arguments_base(
self,
):
reg = registry
- class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture):
+ class Base(
+ MappedAsDataclass, DeclarativeBase, **dc_argument_fixture[0]
+ ):
registry = reg
class A(Base):
x: Mapped[Optional[int]] = mapped_expr_constructor
- self.A = A
+ self._assert_cls(A, dc_argument_fixture[1])
def test_dc_arguments_perclass(
self,
mapped_expr_constructor,
decl_base: Type[DeclarativeBase],
):
- class A(MappedAsDataclass, decl_base, **dc_argument_fixture):
+ class A(MappedAsDataclass, decl_base, **dc_argument_fixture[0]):
__tablename__ = "a"
id: Mapped[int] = mapped_column(primary_key=True, init=False)
x: Mapped[Optional[int]] = mapped_expr_constructor
- self.A = A
+ self._assert_cls(A, dc_argument_fixture[1])
+
+ def test_dc_arguments_override_base(self, registry: _RegistryType):
+ reg = registry
+
+ class Base(MappedAsDataclass, DeclarativeBase, init=False, order=True):
+ registry = reg
+
+ class A(Base, init=True, repr=False):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_column(default=7)
+
+ effective = {
+ "init": True,
+ "repr": False,
+ "eq": True,
+ "order": True,
+ "unsafe_hash": False,
+ }
+ self._assert_cls(A, effective)
+
+ def test_dc_base_unsupported_argument(self, registry: _RegistryType):
+ reg = registry
+ with expect_raises(TypeError):
+
+ class Base(MappedAsDataclass, DeclarativeBase, slots=True):
+ registry = reg
+
+ class Base2(MappedAsDataclass, DeclarativeBase, order=True):
+ registry = reg
+
+ with expect_raises(TypeError):
+
+ class A(Base2, slots=False):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ def test_dc_decorator_unsupported_argument(self, registry: _RegistryType):
+ reg = registry
+ with expect_raises(TypeError):
+
+ @registry.mapped_as_dataclass(slots=True)
+ class Base(DeclarativeBase):
+ registry = reg
+
+ class Base2(MappedAsDataclass, DeclarativeBase, order=True):
+ registry = reg
+
+ with expect_raises(TypeError):
+
+ @registry.mapped_as_dataclass(slots=True)
+ class A(Base2):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ def test_dc_raise_for_slots(
+ self,
+ registry: _RegistryType,
+ decl_base: Type[DeclarativeBase],
+ ):
+ reg = registry
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted",
+ ):
+
+ class A(MappedAsDataclass, decl_base):
+ __tablename__ = "a"
+ _sa_apply_dc_transforms = {"slots": True, "unknown": 5}
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Dataclass argument\(s\) 'slots' are not accepted",
+ ):
+
+ class Base(MappedAsDataclass, DeclarativeBase, order=True):
+ registry = reg
+ _sa_apply_dc_transforms = {"slots": True}
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted",
+ ):
+
+ @reg.mapped
+ class C:
+ __tablename__ = "a"
+ _sa_apply_dc_transforms = {"slots": True, "unknown": 5}
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):