unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ dataclass_callable: Union[
+ _NoArg, Callable[..., Type[Any]]
+ ] = _NoArg.NO_ARG,
) -> None:
apply_dc_transforms: _DataclassArguments = {
"init": init,
"unsafe_hash": unsafe_hash,
"match_args": match_args,
"kw_only": kw_only,
+ "dataclass_callable": dataclass_callable,
}
current_transforms: _DataclassArguments
super().__init_subclass__()
if not _is_mapped_class(cls):
+ new_anno = (
+ _ClassScanMapperConfig._update_annotations_for_non_mapped_class
+ )(cls)
_ClassScanMapperConfig._apply_dataclasses_to_any_class(
- current_transforms, cls
+ current_transforms, cls, new_anno
)
unsafe_hash: Union[_NoArg, bool] = ...,
match_args: Union[_NoArg, bool] = ...,
kw_only: Union[_NoArg, bool] = ...,
+ dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ...,
) -> Callable[[Type[_O]], Type[_O]]:
...
unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG,
match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ dataclass_callable: Union[
+ _NoArg, Callable[..., Type[Any]]
+ ] = _NoArg.NO_ARG,
) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
"""Class decorator that will apply the Declarative mapping process
to a given class, and additionally convert the class to be a
"unsafe_hash": unsafe_hash,
"match_args": match_args,
"kw_only": kw_only,
+ "dataclass_callable": dataclass_callable,
}
_as_declarative(self, cls, cls.__dict__)
return cls
unsafe_hash: Union[_NoArg, bool]
match_args: Union[_NoArg, bool]
kw_only: Union[_NoArg, bool]
+ dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]]
def _declared_mapping_info(
for k, v in defaults.items():
setattr(self.cls, k, v)
- self.cls.__annotations__ = annotations
-
self._apply_dataclasses_to_any_class(
- dataclass_setup_arguments, self.cls
+ dataclass_setup_arguments, self.cls, annotations
)
+ @classmethod
+ def _update_annotations_for_non_mapped_class(
+ cls, klass: Type[_O]
+ ) -> Mapping[str, _AnnotationScanType]:
+ cls_annotations = util.get_annotations(klass)
+
+ new_anno = {}
+ for name, annotation in cls_annotations.items():
+ if _is_mapped_annotation(annotation, klass, klass):
+
+ extracted = _extract_mapped_subtype(
+ annotation,
+ klass,
+ klass.__module__,
+ name,
+ type(None),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=False,
+ )
+ if extracted:
+ inner, _ = extracted
+ new_anno[name] = inner
+ else:
+ new_anno[name] = annotation
+ return new_anno
+
@classmethod
def _apply_dataclasses_to_any_class(
- cls, dataclass_setup_arguments: _DataclassArguments, klass: Type[_O]
+ cls,
+ dataclass_setup_arguments: _DataclassArguments,
+ klass: Type[_O],
+ use_annotations: Mapping[str, _AnnotationScanType],
) -> None:
cls._assert_dc_arguments(dataclass_setup_arguments)
- dataclasses.dataclass(
- klass,
- **{
- k: v
- for k, v in dataclass_setup_arguments.items()
- if v is not _NoArg.NO_ARG
- },
- )
+ dataclass_callable = dataclass_setup_arguments["dataclass_callable"]
+ if dataclass_callable is _NoArg.NO_ARG:
+ dataclass_callable = dataclasses.dataclass
+
+ restored: Optional[Any]
+
+ if use_annotations:
+ # apply constructed annotations that should look "normal" to a
+ # dataclasses callable, based on the fields present. This
+ # means remove the Mapped[] container and ensure all Field
+ # entries have an annotation
+ restored = getattr(klass, "__annotations__", None)
+ klass.__annotations__ = cast("Dict[str, Any]", use_annotations)
+ else:
+ restored = None
+
+ try:
+ dataclass_callable(
+ klass,
+ **{
+ k: v
+ for k, v in dataclass_setup_arguments.items()
+ if v is not _NoArg.NO_ARG and k != "dataclass_callable"
+ },
+ )
+ finally:
+ # restore original annotations outside of the dataclasses
+ # process; for mixins and __abstract__ superclasses, SQLAlchemy
+ # Declarative will need to see the Mapped[] container inside the
+ # annotations in order to map subclasses
+ if use_annotations:
+ if restored is None:
+ del klass.__annotations__
+ else:
+ klass.__annotations__ = restored
@classmethod
def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None:
"unsafe_hash",
"kw_only",
"match_args",
+ "dataclass_callable",
}
disallowed_args = set(arguments).difference(allowed)
if disallowed_args:
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import MappedAsDataclass
from sqlalchemy.orm import MappedColumn
+from sqlalchemy.orm import registry
from sqlalchemy.orm import registry as _RegistryType
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
+from sqlalchemy.testing import Variation
from sqlalchemy.util import compat
a3 = A("data")
eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+ @testing.variation("dc_type", ["decorator", "superclass"])
+ def test_dataclass_fn(self, dc_type: Variation):
+ annotations = {}
+
+ def dc_callable(kls, **kw) -> Type[Any]:
+ annotations[kls] = kls.__annotations__
+ return dataclasses.dataclass(kls, **kw) # type: ignore
+
+ if dc_type.decorator:
+ reg = registry()
+
+ @reg.mapped_as_dataclass(dataclass_callable=dc_callable)
+ class MappedClass:
+ __tablename__ = "mapped_class"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+
+ eq_(annotations, {MappedClass: {"id": int, "name": str}})
+
+ elif dc_type.superclass:
+
+ class Base(DeclarativeBase):
+ pass
+
+ class Mixin(MappedAsDataclass, dataclass_callable=dc_callable):
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ class MappedClass(Mixin, Base):
+ __tablename__ = "mapped_class"
+ name: Mapped[str]
+
+ eq_(
+ annotations,
+ {Mixin: {"id": int}, MappedClass: {"id": int, "name": str}},
+ )
+ else:
+ dc_type.fail()
+
def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]):
class A(dc_decl_base):
__tablename__ = "a"
eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)")
def test_abstract_is_dc(self):
+ collected_annotations = {}
+
+ def check_args(cls, **kw):
+ collected_annotations[cls] = cls.__annotations__
+ return dataclasses.dataclass(cls, **kw)
+
class Parent(DeclarativeBase):
a: int
- class Mixin(MappedAsDataclass, Parent):
+ class Mixin(MappedAsDataclass, Parent, dataclass_callable=check_args):
__abstract__ = True
b: int
__tablename__ = "child"
c: Mapped[int] = mapped_column(primary_key=True)
+ eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}})
+ eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)")
+
+ @testing.variation("check_annotations", [True, False])
+ def test_abstract_is_dc_w_mapped(self, check_annotations):
+ if check_annotations:
+ collected_annotations = {}
+
+ def check_args(cls, **kw):
+ collected_annotations[cls] = cls.__annotations__
+ return dataclasses.dataclass(cls, **kw)
+
+ class_kw = {"dataclass_callable": check_args}
+ else:
+ class_kw = {}
+
+ class Parent(DeclarativeBase):
+ a: int
+
+ class Mixin(MappedAsDataclass, Parent, **class_kw):
+ __abstract__ = True
+ b: Mapped[int] = mapped_column()
+
+ class Child(Mixin):
+ __tablename__ = "child"
+ c: Mapped[int] = mapped_column(primary_key=True)
+
+ if check_annotations:
+ # note: current dataclasses process adds Field() object to Child
+ # based on attributes which include those from Mixin. This means
+ # the annotations of Child are also augmented while we do
+ # dataclasses collection.
+ eq_(
+ collected_annotations,
+ {Mixin: {"b": int}, Child: {"b": int, "c": int}},
+ )
eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)")
def test_mixin_and_base_is_dc(self):
"dataclass_scope",
["on_base", "on_mixin", "on_base_class", "on_sub_class"],
)
- def test_mixin_w_inheritance(self, dataclass_scope):
+ @testing.variation(
+ "test_alternative_callable",
+ [True, False],
+ )
+ def test_mixin_w_inheritance(
+ self, dataclass_scope, test_alternative_callable
+ ):
"""test #9226"""
+ expected_annotations = {}
+
+ if test_alternative_callable:
+ collected_annotations = {}
+
+ def check_args(cls, **kw):
+ collected_annotations[cls] = getattr(
+ cls, "__annotations__", {}
+ )
+ return dataclasses.dataclass(cls, **kw)
+
+ klass_kw = {"dataclass_callable": check_args}
+ else:
+ klass_kw = {}
+
if dataclass_scope.on_base:
- class Base(DeclarativeBase, MappedAsDataclass):
+ class Base(MappedAsDataclass, DeclarativeBase, **klass_kw):
pass
+ expected_annotations[Base] = {}
else:
class Base(DeclarativeBase):
if dataclass_scope.on_mixin:
- class Mixin(MappedAsDataclass):
+ class Mixin(MappedAsDataclass, **klass_kw):
@declared_attr.directive
@classmethod
def __tablename__(cls) -> str:
init=False,
)
+ expected_annotations[Mixin] = {}
else:
class Mixin:
if dataclass_scope.on_base_class:
- class Book(Mixin, MappedAsDataclass, Base):
+ class Book(Mixin, MappedAsDataclass, Base, **klass_kw):
id: Mapped[int] = mapped_column(
Integer,
primary_key=True,
primary_key=True,
)
+ if MappedAsDataclass in Book.__mro__:
+ expected_annotations[Book] = {"id": int, "polymorphic_type": str}
+
if dataclass_scope.on_sub_class:
- class Novel(MappedAsDataclass, Book):
+ class Novel(MappedAsDataclass, Book, **klass_kw):
id: Mapped[int] = mapped_column( # noqa: A001
ForeignKey("book.id"),
primary_key=True,
)
description: Mapped[Optional[str]]
+ expected_annotations[Novel] = {"id": int, "description": Optional[str]}
+
+ if test_alternative_callable:
+ eq_(collected_annotations, expected_annotations)
+
n1 = Novel("the description")
eq_(n1.description, "the description")
x: Mapped[int] = mapped_expr_constructor
def _assert_cls(self, cls, dc_arguments):
-
if dc_arguments["init"]:
def create(data, x):
eq_(a3.x, 7)
def _assert_not_init(self, cls, create, dc_arguments):
-
with expect_raises(TypeError):
cls("Some data", 5)
@testing.fixture
def model(self):
- def go(use_mixin, use_inherits, mad_setup):
-
+ def go(use_mixin, use_inherits, mad_setup, dataclass_kw):
if use_mixin:
-
if mad_setup == "dc, mad":
- class BaseEntity(DeclarativeBase, MappedAsDataclass):
+ class BaseEntity(
+ DeclarativeBase, MappedAsDataclass, **dataclass_kw
+ ):
pass
elif mad_setup == "mad, dc":
- class BaseEntity(MappedAsDataclass, DeclarativeBase):
+ class BaseEntity(
+ MappedAsDataclass, DeclarativeBase, **dataclass_kw
+ ):
pass
elif mad_setup == "subclass":
if mad_setup == "subclass":
- class A(IdMixin, MappedAsDataclass, BaseEntity):
+ class A(
+ IdMixin, MappedAsDataclass, BaseEntity, **dataclass_kw
+ ):
__mapper_args__ = {
"polymorphic_on": "type",
"polymorphic_identity": "a",
data: Mapped[str] = mapped_column(String, init=False)
else:
-
if mad_setup == "dc, mad":
- class BaseEntity(DeclarativeBase, MappedAsDataclass):
+ class BaseEntity(
+ DeclarativeBase, MappedAsDataclass, **dataclass_kw
+ ):
id: Mapped[int] = mapped_column(
primary_key=True, init=False
)
elif mad_setup == "mad, dc":
- class BaseEntity(MappedAsDataclass, DeclarativeBase):
+ class BaseEntity(
+ MappedAsDataclass, DeclarativeBase, **dataclass_kw
+ ):
id: Mapped[int] = mapped_column(
primary_key=True, init=False
)
if mad_setup == "subclass":
- class A(MappedAsDataclass, BaseEntity):
+ class A(MappedAsDataclass, BaseEntity, **dataclass_kw):
__mapper_args__ = {
"polymorphic_on": "type",
"polymorphic_identity": "a",
use_inherits=use_inherits == "inherits",
use_mixin=use_mixin == "mixin",
mad_setup=mad_setup,
+ dataclass_kw={},
)
obj = target_cls()