from .decl_api import DeclarativeMeta as DeclarativeMeta
from .decl_api import declared_attr as declared_attr
from .decl_api import has_inherited_table as has_inherited_table
+from .decl_api import MappedAsDataclass as MappedAsDataclass
from .decl_api import registry as registry
from .decl_api import synonym_for as synonym_for
from .descriptor_props import Composite as Composite
from . import mapperlib as mapperlib
from ._typing import _O
-from .base import Mapped
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .properties import ColumnProperty
from .properties import MappedColumn
from .query import AliasOption
from .. import sql
from .. import util
from ..exc import InvalidRequestError
+from ..sql._typing import _no_kw
+from ..sql.base import _NoArg
from ..sql.base import SchemaEventTarget
from ..sql.schema import SchemaConst
from ..sql.selectable import FromClause
Union[_TypeEngineArgument[Any], SchemaEventTarget]
] = None,
*args: SchemaEventTarget,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
nullable: Optional[
Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]]
] = SchemaConst.NULL_UNSPECIFIED,
name: Optional[str] = None,
type_: Optional[_TypeEngineArgument[Any]] = None,
autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
- default: Optional[Any] = None,
doc: Optional[str] = None,
key: Optional[str] = None,
index: Optional[bool] = None,
type_=type_,
autoincrement=autoincrement,
default=default,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
key=key,
index=index,
deferred: bool = False,
raiseload: bool = False,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
active_history: bool = False,
expire_on_flush: bool = True,
info: Optional[_InfoType] = None,
return ColumnProperty(
column,
*additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
group=group,
deferred=deferred,
raiseload=raiseload,
@overload
def composite(
- class_: Type[_CC],
+ _class_or_attr: Type[_CC],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[_CC]:
...
@overload
def composite(
+ _class_or_attr: _CompositeAttrType[Any],
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
...
def composite(
- class_: Any = None,
+ _class_or_attr: Union[
+ None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
+ ] = None,
*attrs: _CompositeAttrType[Any],
- **kwargs: Any,
+ group: Optional[str] = None,
+ deferred: bool = False,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None,
+ active_history: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+ **__kw: Any,
) -> Composite[Any]:
r"""Return a composite column-based property for use with a Mapper.
:attr:`.MapperProperty.info` attribute of this object.
"""
- return Composite(class_, *attrs, **kwargs)
+ if __kw:
+ raise _no_kw()
+
+ return Composite(
+ _class_or_attr,
+ *attrs,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=deferred,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ info=info,
+ doc=doc,
+ )
def with_loader_criteria(
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
post_update=post_update,
cascade=cascade,
viewonly=viewonly,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
lazy=lazy,
passive_deletes=passive_deletes,
passive_updates=passive_updates,
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Union[_NoArg, _T] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
) -> Synonym[Any]:
map_column=map_column,
descriptor=descriptor,
comparator_factory=comparator_factory,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
doc=doc,
info=info,
)
def deferred(
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
- **kw: Any,
+ group: Optional[str] = None,
+ raiseload: bool = False,
+ comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ init: Union[_NoArg, bool] = _NoArg.NO_ARG,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ default: Optional[Any] = _NoArg.NO_ARG,
+ default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG,
+ active_history: bool = False,
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
) -> ColumnProperty[_T]:
r"""Indicate a column-based mapped attribute that by default will
not load unless accessed.
:ref:`deferred_raiseload`
- :param \**kw: additional keyword arguments passed to
- :class:`.ColumnProperty`.
+ Additional arguments are the same as those of :func:`_orm.column_property`.
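+
+ E.g., a sketch of a mapped attribute inside a SQLAlchemy-native dataclass
+ mapping, following the pattern exercised in this change's tests::
+
+     x: Mapped[int] = deferred(Column(Integer), default=7, init=True)
+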
.. seealso::
:ref:`deferred`
"""
- kw["deferred"] = True
- return ColumnProperty(column, *additional_columns, **kw)
+ return ColumnProperty(
+ column,
+ *additional_columns,
+ attribute_options=_AttributeOptions(
+ init,
+ repr,
+ default,
+ default_factory,
+ ),
+ group=group,
+ deferred=True,
+ raiseload=raiseload,
+ comparator_factory=comparator_factory,
+ active_history=active_history,
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
def query_expression(
default_expr: _ORMColumnExprArgument[_T] = sql.null(),
-) -> Mapped[_T]:
+ *,
+ repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002
+ expire_on_flush: bool = True,
+ info: Optional[_InfoType] = None,
+ doc: Optional[str] = None,
+) -> ColumnProperty[_T]:
"""Indicate an attribute that populates from a query-time SQL expression.
:param default_expr: Optional SQL expression object that will be used in
:ref:`mapper_querytime_expression`
"""
- prop = ColumnProperty(default_expr)
+ prop = ColumnProperty(
+ default_expr,
+ attribute_options=_AttributeOptions(
+ _NoArg.NO_ARG,
+ repr,
+ _NoArg.NO_ARG,
+ _NoArg.NO_ARG,
+ ),
+ expire_on_flush=expire_on_flush,
+ info=info,
+ doc=doc,
+ )
prop.strategy_key = (("query_expression", True),)
return prop
from . import instrumentation
from . import interfaces
from . import mapperlib
+from ._orm_constructors import column_property
+from ._orm_constructors import composite
+from ._orm_constructors import deferred
+from ._orm_constructors import mapped_column
+from ._orm_constructors import query_expression
+from ._orm_constructors import relationship
+from ._orm_constructors import synonym
from .attributes import InstrumentedAttribute
from .base import _inspect_mapped_class
from .base import Mapped
from .decl_base import _DeferredMapperConfig
from .decl_base import _del_attribute
from .decl_base import _mapper
+from .descriptor_props import Composite
+from .descriptor_props import Synonym
from .descriptor_props import Synonym as _orm_synonym
from .mapper import Mapper
+from .properties import ColumnProperty
+from .properties import MappedColumn
+from .relationships import Relationship
from .state import InstanceState
from .. import exc
from .. import inspection
if TYPE_CHECKING:
from ._typing import _O
from ._typing import _RegistryType
- from .descriptor_props import Synonym
from .instrumentation import ClassManager
from .interfaces import MapperProperty
+ from .state import InstanceState # noqa
from ..sql._typing import _TypeEngineArgument
_T = TypeVar("_T", bound=Any)
"""
+@compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+)
+class DCTransformDeclarative(DeclarativeAttributeIntercept):
+ """metaclass that includes @dataclass_transforms"""
+
+
class DeclarativeMeta(
_DynamicAttributesType, inspection.Inspectable[Mapper[Any]]
):
cls._sa_registry.map_declaratively(cls)
+class MappedAsDataclass(metaclass=DCTransformDeclarative):
+ """Mixin class to indicate when mapping this class, also convert it to be
+ a dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
+ .. versionadded:: 2.0
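+
+    A minimal usage sketch, mirroring the fixtures in this change's tests
+    (table and attribute names are illustrative)::
+
+        from typing import Optional
+
+        from sqlalchemy.orm import DeclarativeBase
+        from sqlalchemy.orm import Mapped
+        from sqlalchemy.orm import mapped_column
+        from sqlalchemy.orm import MappedAsDataclass
+
+        class Base(MappedAsDataclass, DeclarativeBase):
+            pass
+
+        class User(Base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            name: Mapped[str]
+            status: Mapped[Optional[str]] = mapped_column(default=None)
+
+        # a dataclass-style __init__ and __repr__ are generated
+        u1 = User("some name")
+
+    Per-class dataclass arguments such as ``unsafe_hash=True`` may be
+    passed as class-definition keywords, as exercised in the tests for
+    this change.
+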
+ """
+
+ def __init_subclass__(
+ cls,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> None:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ super().__init_subclass__()
+
+
class DeclarativeBase(
inspection.Inspectable[InstanceState[Any]],
metaclass=DeclarativeAttributeIntercept,
):
"""Base class used for declarative class definitions.
+
The :class:`_orm.DeclarativeBase` allows for the creation of new
declarative bases in such a way that is compatible with type checkers::
bases = not isinstance(cls, tuple) and (cls,) or cls
- class_dict = dict(registry=self, metadata=metadata)
+ class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata)
if isinstance(cls, type):
class_dict["__doc__"] = cls.__doc__
return metaclass(name, bases, class_dict)
+ @compat_typing.dataclass_transform(
+ field_descriptors=(
+ MappedColumn[Any],
+ Relationship[Any],
+ Composite[Any],
+ ColumnProperty[Any],
+ Synonym[Any],
+ mapped_column,
+ relationship,
+ composite,
+ column_property,
+ synonym,
+ deferred,
+ query_expression,
+ ),
+ )
+ @overload
+ def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]:
+ ...
+
+ @overload
+ def mapped_as_dataclass(
+ self,
+ __cls: Literal[None] = ...,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Callable[[Type[_O]], Type[_O]]:
+ ...
+
+ def mapped_as_dataclass(
+ self,
+ __cls: Optional[Type[_O]] = None,
+ *,
+ init: bool = True,
+ repr: bool = True, # noqa: A002
+ eq: bool = True,
+ order: bool = False,
+ unsafe_hash: bool = False,
+ ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]:
+ """Class decorator that will apply the Declarative mapping process
+        to a given class, and additionally convert the class into a
+        Python dataclass.
+
+ .. seealso::
+
+ :meth:`_orm.registry.mapped`
+
+ .. versionadded:: 2.0
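+
+        A sketch of decorator use, following the patterns in this change's
+        tests (names are illustrative)::
+
+            from sqlalchemy.orm import Mapped
+            from sqlalchemy.orm import mapped_column
+            from sqlalchemy.orm import registry
+
+            reg = registry()
+
+            @reg.mapped_as_dataclass
+            class User:
+                __tablename__ = "user"
+
+                id: Mapped[int] = mapped_column(
+                    primary_key=True, init=False
+                )
+                name: Mapped[str]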
+
+
+ """
+
+ def decorate(cls: Type[_O]) -> Type[_O]:
+ cls._sa_apply_dc_transforms = {
+ "init": init,
+ "repr": repr,
+ "eq": eq,
+ "order": order,
+ "unsafe_hash": unsafe_hash,
+ }
+ _as_declarative(self, cls, cls.__dict__)
+ return cls
+
+ if __cls:
+ return decorate(__cls)
+ else:
+ return decorate
+
def mapped(self, cls: Type[_O]) -> Type[_O]:
"""Class decorator that will apply the Declarative mapping process
to a given class.
that will apply Declarative mapping to subclasses automatically
using a Python metaclass.
+ .. seealso::
+
+ :meth:`_orm.registry.mapped_as_dataclass`
+
"""
_as_declarative(self, cls, cls.__dict__)
return cls
from __future__ import annotations
import collections
+import dataclasses
+import re
from typing import Any
from typing import Callable
from typing import cast
from .base import InspectionAttr
from .descriptor_props import Composite
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MappedAttribute
from .interfaces import _MapsColumns
from .mapper import Mapper
from .properties import ColumnProperty
from .properties import MappedColumn
+from .util import _extract_mapped_subtype
from .util import _is_mapped_annotation
from .util import class_mapper
from .. import event
from .. import exc
from .. import util
from ..sql import expression
+from ..sql.base import _NoArg
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import topological
+from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
if TYPE_CHECKING:
"mapper_args",
"mapper_args_fn",
"inherits",
+ "allow_dataclass_fields",
+ "dataclass_setup_arguments",
)
registry: _RegistryType
clsdict_view: _ClassDict
- collected_annotations: Dict[str, Tuple[Any, bool]]
+ collected_annotations: Dict[str, Tuple[Any, Any, bool]]
collected_attributes: Dict[str, Any]
local_table: Optional[FromClause]
persist_selectable: Optional[FromClause]
mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
inherits: Optional[Type[Any]]
+ dataclass_setup_arguments: Optional[Dict[str, Any]]
+ """if the class has SQLAlchemy native dataclass parameters, where
+ we will create a SQLAlchemy dataclass (not a real dataclass).
+
+ """
+
+    allow_dataclass_fields: bool
+    """if True, look for dataclass-processed Field objects on the target
+    class as well as superclasses and extract ORM mapping directives from
+    the "metadata" attribute of each Field.
+
def __init__(
self,
registry: _RegistryType,
self.declared_columns = util.OrderedSet()
self.column_copies = {}
+ self.dataclass_setup_arguments = dca = getattr(
+ self.cls, "_sa_apply_dc_transforms", None
+ )
+
+ cld = dataclasses.is_dataclass(cls_)
+
+ sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
+
+    # we don't want to consume Field objects from a class that is not
+    # already a dataclass: the Field objects won't have their "name" or
+    # "type" populated, and while it seems like we could just set these
+    # on Field as we read them, Field is documented as "user read only"
+    # and we need to stay far away from any off-label use of
+    # dataclasses APIs.
+ if (not cld or dca) and sdk:
+ raise exc.InvalidRequestError(
+ "SQLAlchemy mapped dataclasses can't consume mapping "
+ "information from dataclass.Field() objects if the immediate "
+ "class is not already a dataclass."
+ )
+
+    # if already a dataclass, and __sa_dataclass_metadata_key__ present,
+    # then also look inside of dataclasses.Field() objects yielded by
+    # dataclasses.fields(cls) when scanning for attributes
+ self.allow_dataclass_fields = bool(sdk and cld)
+
self._setup_declared_events()
self._scan_attributes()
+ self._setup_dataclasses_transforms()
+
with mapperlib._CONFIGURE_MUTEX:
clsregistry.add_class(
self.classname, self.cls, registry._class_registry
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
- if sa_dataclass_metadata_key is None:
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def attribute_is_overridden(key: str, obj: Any) -> bool:
return getattr(cls, key) is not obj
"__dict__",
"__weakref__",
"_sa_class_manager",
+ "_sa_apply_dc_transforms",
"__dict__",
"__weakref__",
]
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
- sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
- cls, "__sa_dataclass_metadata_key__"
- )
-
cls_annotations = util.get_annotations(cls)
cls_vars = vars(cls)
names = util.merge_lists_w_ordering(
[n for n in cls_vars if n not in skip], list(cls_annotations)
)
- if sa_dataclass_metadata_key is None:
+
+ if self.allow_dataclass_fields:
+ sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__"
+ )
+ else:
+ sa_dataclass_metadata_key = None
+
+ if not sa_dataclass_metadata_key:
def local_attributes_for_class() -> Iterable[
Tuple[str, Any, Any, bool]
name,
obj,
annotation,
- is_dataclass,
+ is_dataclass_field,
) in local_attributes_for_class():
- if name == "__mapper_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not mapper_args_fn and (not class_mapped or check_decl):
- # don't even invoke __mapper_args__ until
- # after we've determined everything about the
- # mapped table.
- # make a copy of it so a class-level dictionary
- # is not overwritten when we update column-based
- # arguments.
- def _mapper_args_fn() -> Dict[str, Any]:
- return dict(cls_as_Decl.__mapper_args__)
-
- mapper_args_fn = _mapper_args_fn
-
- elif name == "__tablename__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not tablename and (not class_mapped or check_decl):
- tablename = cls_as_Decl.__tablename__
- elif name == "__table_args__":
- check_decl = _check_declared_props_nocascade(
- obj, name, cls
- )
- if not table_args and (not class_mapped or check_decl):
- table_args = cls_as_Decl.__table_args__
- if not isinstance(
- table_args, (tuple, dict, type(None))
+ if re.match(r"^__.+__$", name):
+ if name == "__mapper_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not mapper_args_fn and (
+ not class_mapped or check_decl
):
- raise exc.ArgumentError(
- "__table_args__ value must be a tuple, "
- "dict, or None"
- )
- if base is not cls:
- inherited_table_args = True
+ # don't even invoke __mapper_args__ until
+ # after we've determined everything about the
+ # mapped table.
+ # make a copy of it so a class-level dictionary
+ # is not overwritten when we update column-based
+ # arguments.
+ def _mapper_args_fn() -> Dict[str, Any]:
+ return dict(cls_as_Decl.__mapper_args__)
+
+ mapper_args_fn = _mapper_args_fn
+
+ elif name == "__tablename__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not tablename and (not class_mapped or check_decl):
+ tablename = cls_as_Decl.__tablename__
+ elif name == "__table_args__":
+ check_decl = _check_declared_props_nocascade(
+ obj, name, cls
+ )
+ if not table_args and (not class_mapped or check_decl):
+ table_args = cls_as_Decl.__table_args__
+ if not isinstance(
+ table_args, (tuple, dict, type(None))
+ ):
+ raise exc.ArgumentError(
+ "__table_args__ value must be a tuple, "
+ "dict, or None"
+ )
+ if base is not cls:
+ inherited_table_args = True
+ else:
+ # skip all other dunder names
+ continue
elif class_mapped:
if _is_declarative_props(obj):
util.warn(
# acting like that for now.
if isinstance(obj, (Column, MappedColumn)):
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
# already copied columns to the mapped class.
continue
] = ret = obj.__get__(obj, cls)
setattr(cls, name, ret)
else:
- if is_dataclass:
+ if is_dataclass_field:
# access attribute using normal class access
# first, to see if it's been mapped on a
# superclass. note if the dataclasses.field()
):
ret.doc = obj.__doc__
- self.collected_annotations[name] = (
+ self._collect_annotation(
+ name,
obj._collect_return_annotation(),
False,
+ True,
+ obj,
)
elif _is_mapped_annotation(annotation, cls):
- self.collected_annotations[name] = (
- annotation,
- is_dataclass,
+ self._collect_annotation(
+ name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
# declarative mapping. however, check for some
# more common mistakes
self._warn_for_decl_attributes(base, name, obj)
- elif is_dataclass and (
+ elif is_dataclass_field and (
name not in clsdict_view or clsdict_view[name] is not obj
):
# here, we are definitely looking at the target class
obj = obj.fget()
collected_attributes[name] = obj
- self.collected_annotations[name] = (
- annotation,
- True,
+ self._collect_annotation(
+ name, annotation, True, False, obj
)
else:
- self.collected_annotations[name] = (
- annotation,
- False,
+ self._collect_annotation(
+ name, annotation, False, None, obj
)
if (
obj is None
collected_attributes[name] = MappedColumn()
elif name in clsdict_view:
collected_attributes[name] = obj
+ # else if the name is not in the cls.__dict__,
+ # don't collect it as an attribute.
+ # we will see the annotation only, which is meaningful
+ # both for mapping and dataclasses setup
if inherited_table_args and not tablename:
table_args = None
self.tablename = tablename
self.mapper_args_fn = mapper_args_fn
+ def _setup_dataclasses_transforms(self) -> None:
+
+ dataclass_setup_arguments = self.dataclass_setup_arguments
+ if not dataclass_setup_arguments:
+ return
+
+ manager = instrumentation.manager_of_class(self.cls)
+ assert manager is not None
+
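+        # build the (name, annotation[, default/Field]) tuples in the
+        # form accepted by dataclasses.make_dataclass(), preferring the
+        # extracted Mapped[...] inner type over the raw annotation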
+ field_list = [
+ _AttributeOptions._get_arguments_for_make_dataclass(
+ key,
+ anno,
+ self.collected_attributes.get(key, _NoArg.NO_ARG),
+ )
+ for key, anno in (
+ (key, mapped_anno if mapped_anno else raw_anno)
+ for key, (
+ raw_anno,
+ mapped_anno,
+ is_dc,
+ ) in self.collected_annotations.items()
+ )
+ ]
+
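+        # write the collected annotations and any defaults directly onto
+        # the class, then apply the stock @dataclasses.dataclass decorator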
+ annotations = {}
+ defaults = {}
+ for item in field_list:
+ if len(item) == 2:
+ name, tp = item # type: ignore
+ elif len(item) == 3:
+ name, tp, spec = item # type: ignore
+ defaults[name] = spec
+ else:
+ assert False
+ annotations[name] = tp
+
+ for k, v in defaults.items():
+ setattr(self.cls, k, v)
+ self.cls.__annotations__ = annotations
+
+ dataclasses.dataclass(self.cls, **dataclass_setup_arguments)
+
+ def _collect_annotation(
+ self,
+ name: str,
+ raw_annotation: _AnnotationScanType,
+ is_dataclass: bool,
+ expect_mapped: Optional[bool],
+ attr_value: Any,
+ ) -> None:
+
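+        # if the caller doesn't state whether a Mapped[] annotation is
+        # expected, infer it from whether the attribute value is itself
+        # a mapped construct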
+ if expect_mapped is None:
+ expect_mapped = isinstance(attr_value, _MappedAttribute)
+
+ extracted_mapped_annotation = _extract_mapped_subtype(
+ raw_annotation,
+ self.cls,
+ name,
+ type(attr_value),
+ required=False,
+ is_dataclass_field=False,
+ expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+ )
+
+ self.collected_annotations[name] = (
+ raw_annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
) -> None:
_undefer_column_name(
k, self.column_copies.get(value, value) # type: ignore
)
- elif isinstance(value, _IntrospectsAnnotations):
- annotation, is_dataclass = self.collected_annotations.get(
- k, (None, False)
- )
- value.declarative_scan(
- self.registry, cls, k, annotation, is_dataclass
- )
+ else:
+ if isinstance(value, _IntrospectsAnnotations):
+ (
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ ) = self.collected_annotations.get(k, (None, None, False))
+ value.declarative_scan(
+ self.registry,
+ cls,
+ k,
+ annotation,
+ extracted_mapped_annotation,
+ is_dataclass,
+ )
+
+ if (
+ isinstance(value, (MapperProperty, _MapsColumns))
+ and value._has_dataclass_arguments
+ and not self.dataclass_setup_arguments
+ ):
+ if isinstance(value, MapperProperty):
+ argnames = [
+ "init",
+ "default_factory",
+ "repr",
+ "default",
+ ]
+ else:
+ argnames = ["init", "default_factory", "repr"]
+
+ args = {
+ a
+ for a in argnames
+ if getattr(
+ value._attribute_options, f"dataclasses_{a}"
+ )
+ is not _NoArg.NO_ARG
+ }
+ raise exc.ArgumentError(
+ f"Attribute '{k}' on class {cls} includes dataclasses "
+ f"argument(s): "
+ f"{', '.join(sorted(repr(a) for a in args))} but "
+ f"class does not specify "
+ "SQLAlchemy native dataclass configuration."
+ )
+
our_stuff[k] = value
def _extract_declared_columns(self) -> None:
# extract columns from the class dict
declared_columns = self.declared_columns
name_to_prop_key = collections.defaultdict(set)
+
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
# otherwise, Mapper will map it under the column key.
if mp_to_assign is None and key != col.key:
our_stuff[key] = col
-
elif isinstance(c, Column):
# undefer previously occurred here, and now occurs earlier.
# ensure every column we get here has been named
from .base import Mapped
from .base import PassiveFlag
from .base import SQLORMOperations
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
-from .util import _extract_mapped_subtype
from .util import _none_set
from .. import event
from .. import exc as sa_exc
def __init__(
self,
- class_: Union[
+ _class_or_attr: Union[
None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any]
] = None,
*attrs: _CompositeAttrType[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
active_history: bool = False,
deferred: bool = False,
group: Optional[str] = None,
comparator_factory: Optional[Type[Comparator[_CC]]] = None,
info: Optional[_InfoType] = None,
+ **kwargs: Any,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
- if isinstance(class_, (Mapped, str, sql.ColumnElement)):
- self.attrs = (class_,) + attrs
+ if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)):
+ self.attrs = (_class_or_attr,) + attrs
# will initialize within declarative_scan
self.composite_class = None # type: ignore
else:
- self.composite_class = class_ # type: ignore
+ self.composite_class = _class_or_attr # type: ignore
self.attrs = attrs
self.active_history = active_history
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- MappedColumn = util.preloaded.orm_properties.MappedColumn
-
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- self.composite_class is None,
- is_dataclass_field,
- )
-
+ if (
+ self.composite_class is None
+ and extracted_mapped_annotation is None
+ ):
+ self._raise_for_required(key, cls)
+ argument = extracted_mapped_annotation
if argument and self.composite_class is None:
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
for param, attr in itertools.zip_longest(
insp.parameters.values(), self.attrs
):
- if param is None or attr is None:
+ if param is None:
raise sa_exc.ArgumentError(
- f"number of arguments to {self.composite_class.__name__} "
- f"class and number of attributes don't match"
+ f"number of composite attributes "
+ f"{len(self.attrs)} exceeds "
+ f"that of the number of attributes in class "
+ f"{self.composite_class.__name__} {len(insp.parameters)}"
)
+ if attr is None:
+ # fill in missing attr spots with empty MappedColumn
+ attr = MappedColumn()
+ self.attrs += (attr,)
+
if isinstance(attr, MappedColumn):
attr.declarative_scan_for_composite(
registry, cls, key, param.name, param.annotation
map_column: Optional[bool] = None,
descriptor: Optional[Any] = None,
comparator_factory: Optional[Type[PropComparator[_T]]] = None,
+ attribute_options: Optional[_AttributeOptions] = None,
info: Optional[_InfoType] = None,
doc: Optional[str] = None,
):
- super().__init__()
+ super().__init__(attribute_options=attribute_options)
self.name = name
self.map_column = map_column
"previously known as deferred_scalar_loader"
init_method: Optional[Callable[..., None]]
+ original_init: Optional[Callable[..., None]] = None
factory: Optional[_ManagerFactory]
if finalize and not self._finalized:
self._finalize()
- def _finalize(self):
+ def _finalize(self) -> None:
if self._finalized:
return
self._finalized = True
_instrumentation_factory.dispatch.class_instrument(self.class_)
- def __hash__(self):
+ def __hash__(self) -> int: # type: ignore[override]
return id(self)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return other is self
@property
- def is_mapped(self):
+ def is_mapped(self) -> bool:
return "mapper" in self.__dict__
@HasMemoized.memoized_attribute
from __future__ import annotations
import collections
+import dataclasses
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List
+from typing import NamedTuple
+from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
from .base import RelationshipDirection as RelationshipDirection # noqa: F401
from .base import SQLORMOperations
from .. import ColumnElement
+from .. import exc as sa_exc
from .. import inspection
from .. import util
from ..sql import operators
from ..sql import roles
from ..sql import visitors
+from ..sql.base import _NoArg
from ..sql.base import ExecutableOption
from ..sql.cache_key import HasCacheKey
from ..sql.schema import Column
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
"""Perform class-specific initializaton at early declarative scanning
"""
+ def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn:
+ raise sa_exc.ArgumentError(
+ f"Python typing annotation is required for attribute "
+ f'"{cls.__name__}.{key}" when primary argument(s) for '
+ f'"{self.__class__.__name__}" construct are None or not present'
+ )
+
+
+class _AttributeOptions(NamedTuple):
+ """define Python-local attribute behavior options common to all
+ :class:`.MapperProperty` objects.
+
+ Currently this includes dataclass-generation arguments.
+
+ .. versionadded:: 2.0
+
+ """
+
+ dataclasses_init: Union[_NoArg, bool]
+ dataclasses_repr: Union[_NoArg, bool]
+ dataclasses_default: Union[_NoArg, Any]
+ dataclasses_default_factory: Union[_NoArg, Callable[[], Any]]
+
+ def _as_dataclass_field(self) -> Any:
+ """Return a ``dataclasses.Field`` object given these arguments."""
+
+ kw: Dict[str, Any] = {}
+ if self.dataclasses_default_factory is not _NoArg.NO_ARG:
+ kw["default_factory"] = self.dataclasses_default_factory
+ if self.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = self.dataclasses_default
+ if self.dataclasses_init is not _NoArg.NO_ARG:
+ kw["init"] = self.dataclasses_init
+ if self.dataclasses_repr is not _NoArg.NO_ARG:
+ kw["repr"] = self.dataclasses_repr
+
+ return dataclasses.field(**kw)
+
+ @classmethod
+ def _get_arguments_for_make_dataclass(
+ cls, key: str, annotation: Type[Any], elem: _T
+ ) -> Union[
+ Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+ ]:
+ """given attribute key, annotation, and value from a class, return
+ the argument tuple we would pass to dataclasses.make_dataclass()
+ for this attribute.
+
+ """
+ if isinstance(elem, (MapperProperty, _MapsColumns)):
+ dc_field = elem._attribute_options._as_dataclass_field()
+
+ return (key, annotation, dc_field)
+ elif elem is not _NoArg.NO_ARG:
+ # why is typing not erroring on this?
+ return (key, annotation, elem)
+ else:
+ return (key, annotation)
+
+
+_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions(
+ _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG
+)
+
class _MapsColumns(_MappedAttribute[_T]):
"""interface for declarative-capable construct that delivers one or more
__slots__ = ()
+ _attribute_options: _AttributeOptions
+ _has_dataclass_arguments: bool
+
@property
def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]:
"""return a MapperProperty to be assigned to the declarative mapping"""
__slots__ = (
"_configure_started",
"_configure_finished",
+ "_attribute_options",
+ "_has_dataclass_arguments",
"parent",
"key",
"info",
doc: Optional[str]
"""optional documentation string"""
+ _attribute_options: _AttributeOptions
+ """behavioral options for ORM-enabled Python attributes
+
+ .. versionadded:: 2.0
+
+ """
+
+ _has_dataclass_arguments: bool
+
def _memoized_attr_info(self) -> _InfoType:
"""Info dictionary associated with the object, allowing user-defined
data to be associated with this :class:`.InspectionAttr`.
"""
- def __init__(self) -> None:
+ def __init__(
+ self, attribute_options: Optional[_AttributeOptions] = None
+ ) -> None:
self._configure_started = False
self._configure_finished = False
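+        # an explicit, non-default _AttributeOptions marks this property
+        # as carrying dataclass-related arguments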
+ if (
+ attribute_options
+ and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS
+ ):
+ self._has_dataclass_arguments = True
+ self._attribute_options = attribute_options
+ else:
+ self._has_dataclass_arguments = False
+ self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS
def init(self) -> None:
"""Called after all mappers are created to assemble
from .descriptor_props import Composite
from .descriptor_props import ConcreteInheritedProperty
from .descriptor_props import Synonym
+from .interfaces import _AttributeOptions
+from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS
from .interfaces import _IntrospectsAnnotations
from .interfaces import _MapsColumns
from .interfaces import MapperProperty
from .interfaces import PropComparator
from .interfaces import StrategizedProperty
from .relationships import Relationship
-from .util import _extract_mapped_subtype
from .util import _orm_full_deannotate
from .. import exc as sa_exc
from .. import ForeignKey
from ..sql import coercions
from ..sql import roles
from ..sql import sqltypes
+from ..sql.base import _NoArg
from ..sql.elements import SQLCoreOperations
from ..sql.schema import Column
from ..sql.schema import SchemaConst
self,
column: _ORMColumnExprArgument[_T],
*additional_columns: _ORMColumnExprArgument[Any],
+ attribute_options: Optional[_AttributeOptions] = None,
group: Optional[str] = None,
deferred: bool = False,
raiseload: bool = False,
doc: Optional[str] = None,
_instrument: bool = True,
):
- super(ColumnProperty, self).__init__()
+ super(ColumnProperty, self).__init__(
+ attribute_options=attribute_options
+ )
columns = (column,) + additional_columns
self._orig_columns = [
coercions.expect(roles.LabeledColumnExprRole, c) for c in columns
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.columns[0]
"foreign_keys",
"_has_nullable",
"deferred",
+ "_attribute_options",
+ "_has_dataclass_arguments",
)
deferred: bool
column: Column[_T]
foreign_keys: Optional[Set[ForeignKey]]
+ _attribute_options: _AttributeOptions
def __init__(self, *arg: Any, **kw: Any):
+ self._attribute_options = attr_opts = kw.pop(
+ "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
+ )
+
+ self._has_dataclass_arguments = False
+
+ if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
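+            # dataclass-level defaults double as the Column-level default;
+            # Column() accepts either a plain value or a callable here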
+ if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
+ self._has_dataclass_arguments = True
+ kw["default"] = attr_opts.dataclasses_default_factory
+ elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = attr_opts.dataclasses_default
+
+ if (
+ attr_opts.dataclasses_init is not _NoArg.NO_ARG
+ or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
+ ):
+ self._has_dataclass_arguments = True
+
+ if "default" in kw and kw["default"] is _NoArg.NO_ARG:
+ kw.pop("default")
+
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
self.foreign_keys = self.column.foreign_keys
new.deferred = self.deferred
new.foreign_keys = new.column.foreign_keys
new._has_nullable = self._has_nullable
+ new._attribute_options = self._attribute_options
+ new._has_dataclass_arguments = self._has_dataclass_arguments
util.set_creation_order(new)
return new
@property
def mapper_property_to_assign(self) -> Optional["MapperProperty[_T]"]:
if self.deferred:
- return ColumnProperty(self.column, deferred=True)
+ return ColumnProperty(
+ self.column,
+ deferred=True,
+ attribute_options=self._attribute_options,
+ )
else:
return None
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
column = self.column
sqltype = column.type
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- MappedColumn,
- sqltype._isnull and not self.column.foreign_keys,
- is_dataclass_field,
- )
- if argument is None:
- return
+ if extracted_mapped_annotation is None:
+ if sqltype._isnull and not self.column.foreign_keys:
+ self._raise_for_required(key, cls)
+ else:
+ return
- self._init_column_for_annotation(cls, registry, argument)
+ self._init_column_for_annotation(
+ cls, registry, extracted_mapped_annotation
+ )
@util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan_for_composite(
from .base import LoaderCallableStatus
from .base import PassiveFlag
from .base import state_str
+from .interfaces import _AttributeOptions
from .interfaces import _IntrospectsAnnotations
from .interfaces import MANYTOMANY
from .interfaces import MANYTOONE
from .interfaces import PropComparator
from .interfaces import RelationshipDirection
from .interfaces import StrategizedProperty
-from .util import _extract_mapped_subtype
from .util import _orm_annotate
from .util import _orm_deannotate
from .util import CascadeOptions
post_update: bool = False,
cascade: str = "save-update, merge",
viewonly: bool = False,
+ attribute_options: Optional[_AttributeOptions] = None,
lazy: _LazyLoadArgumentType = "select",
passive_deletes: Union[Literal["all"], bool] = False,
passive_updates: bool = True,
_local_remote_pairs: Optional[_ColumnPairs] = None,
_legacy_inactive_history_style: bool = False,
):
- super(Relationship, self).__init__()
+ super(Relationship, self).__init__(attribute_options=attribute_options)
self.uselist = uselist
self.argument = argument
cls: Type[Any],
key: str,
annotation: Optional[_AnnotationScanType],
+ extracted_mapped_annotation: Optional[_AnnotationScanType],
is_dataclass_field: bool,
) -> None:
- argument = _extract_mapped_subtype(
- annotation,
- cls,
- key,
- Relationship,
- self.argument is None,
- is_dataclass_field,
- )
- if argument is None:
- return
+ argument = extracted_mapped_annotation
+
+ if extracted_mapped_annotation is None:
+
+ if self.argument is None:
+ self._raise_for_required(key, cls)
+ else:
+ return
+
if hasattr(argument, "__origin__"):
def _is_mapped_annotation(
- raw_annotation: Union[type, str], cls: Type[Any]
+ raw_annotation: _AnnotationScanType, cls: Type[Any]
) -> bool:
annotated = de_stringify_annotation(cls, raw_annotation)
return is_origin_of(annotated, "Mapped", module="sqlalchemy.orm")
attr_cls: Type[Any],
required: bool,
is_dataclass_field: bool,
- superclasses: Optional[Tuple[Type[Any], ...]] = None,
+ expect_mapped: bool = True,
) -> Optional[Union[type, str]]:
+ """given an annotation, figure out if it's ``Mapped[something]`` and if
+    so, return the ``something`` part.
+
+    Includes error raise scenarios and other options.
+
+ """
if raw_annotation is None:
if required:
if is_dataclass_field:
return annotated
else:
- # TODO: there don't seem to be tests for the failure
- # conditions here
- if not hasattr(annotated, "__origin__") or (
- not issubclass(
- annotated.__origin__, # type: ignore
- superclasses if superclasses else attr_cls,
- )
- and not issubclass(attr_cls, annotated.__origin__) # type: ignore
+ if not hasattr(annotated, "__origin__") or not is_origin_of(
+ annotated, "Mapped", module="sqlalchemy.orm"
):
- our_annotated_str = (
- annotated.__name__
+ anno_name = (
+ getattr(annotated, "__name__", None)
if not isinstance(annotated, str)
- else repr(annotated)
- )
- raise sa_exc.ArgumentError(
- f'Type annotation for "{cls.__name__}.{key}" should use the '
- f'syntax "Mapped[{our_annotated_str}]" or '
- f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ else None
)
+ if anno_name is None:
+ our_annotated_str = repr(annotated)
+ else:
+ our_annotated_str = anno_name
+
+ if expect_mapped:
+ raise sa_exc.ArgumentError(
+ f'Type annotation for "{cls.__name__}.{key}" '
+ "should use the "
+ f'syntax "Mapped[{our_annotated_str}]" or '
+ f'"{attr_cls.__name__}[{our_annotated_str}]".'
+ )
+
+ else:
+ return annotated
if len(annotated.__args__) != 1: # type: ignore
raise sa_exc.ArgumentError(
from .. import util
from ..orm import declarative_base
from ..orm import DeclarativeBase
+from ..orm import MappedAsDataclass
from ..orm import registry
from ..schema import sort_tables_and_constraints
@config.fixture()
def registry(self, metadata):
- reg = registry(metadata=metadata)
+ reg = registry(
+ metadata=metadata,
+ type_annotation_map={
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ },
+ )
yield reg
reg.dispose()
yield Base
Base.registry.dispose()
+ @config.fixture
+ def dc_decl_base(self, metadata):
+ _md = metadata
+
+ class Base(MappedAsDataclass, DeclarativeBase):
+ metadata = _md
+ type_annotation_map = {
+ str: sa.String().with_variant(
+ sa.String(50), "mysql", "mariadb"
+ )
+ }
+
+ yield Base
+ Base.registry.dispose()
+
@config.fixture()
def future_connection(self, future_engine, connection):
# integrate the future_engine and connection fixtures so
def dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated
- with a class."""
+    with an already-processed dataclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
def local_dataclass_fields(cls: Type[Any]) -> Iterable[dataclasses.Field[Any]]:
"""Return a sequence of all dataclasses.Field objects associated with
- a class, excluding those that originate from a superclass."""
+    an already-processed dataclass, excluding those that originate from a
+ superclass.
+
+ The class must **already be a dataclass** for Field objects to be returned.
+
+ """
if dataclasses.is_dataclass(cls):
super_fields: Set[dataclasses.Field[Any]] = set()
from . import compat
+
+# more zimports issues
+if True:
+ from typing_extensions import ( # noqa: F401
+ dataclass_transform as dataclass_transform,
+ )
+
+
_T = TypeVar("_T", bound=Any)
_KT = TypeVar("_KT")
_KT_co = TypeVar("_KT_co", covariant=True)
[tool.pyright]
-
reportPrivateUsage = "none"
reportUnusedClass = "none"
reportUnusedFunction = "none"
install_requires =
importlib-metadata;python_version<"3.8"
greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32'))))))
- typing-extensions >= 4
+ typing-extensions >= 4.1.0
[options.extras_require]
asyncio =
--- /dev/null
+import dataclasses
+import inspect as pyinspect
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Type
+from unittest import mock
+
+from sqlalchemy import Column
+from sqlalchemy import exc
+from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import column_property
+from sqlalchemy.orm import composite
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedAsDataclass
+from sqlalchemy.orm import MappedColumn
+from sqlalchemy.orm import registry as _RegistryType
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import synonym
+from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_regex
+from sqlalchemy.testing import expect_raises
+from sqlalchemy.testing import expect_raises_message
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
+from sqlalchemy.testing import ne_
+
+
+class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
+ def test_basic_constructor_repr_base_cls(
+ self, dc_decl_base: Type[MappedAsDataclass]
+ ):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ class B(dc_decl_base):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ A.__qualname__ = "some_module.A"
+ B.__qualname__ = "some_module.B"
+
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x", "bs"],
+ varargs=None,
+ varkw=None,
+ defaults=(None, mock.ANY),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+ eq_(
+ pyinspect.getfullargspec(B.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x"],
+ varargs=None,
+ varkw=None,
+ defaults=(None,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+ eq_(
+ repr(a2),
+ "some_module.A(id=None, data='10', x=5, "
+ "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+ "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+ )
+
+ a3 = A("data")
+ eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+ def test_basic_constructor_repr_cls_decorator(
+ self, registry: _RegistryType
+ ):
+ @registry.mapped_as_dataclass()
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ @registry.mapped_as_dataclass()
+ class B:
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ A.__qualname__ = "some_module.A"
+ B.__qualname__ = "some_module.B"
+
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x", "bs"],
+ varargs=None,
+ varkw=None,
+ defaults=(None, mock.ANY),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+ eq_(
+ pyinspect.getfullargspec(B.__init__),
+ pyinspect.FullArgSpec(
+ args=["self", "data", "x"],
+ varargs=None,
+ varkw=None,
+ defaults=(None,),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+ eq_(
+ repr(a2),
+ "some_module.A(id=None, data='10', x=5, "
+ "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
+ "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+ )
+
+ a3 = A("data")
+ eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
+
+ def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(default="d1")
+ data2: Mapped[str] = mapped_column(default_factory=lambda: "d2")
+
+ a1 = A()
+ eq_(a1.data, "d1")
+ eq_(a1.data2, "d2")
+
+ def test_default_factory_vs_collection_class(
+ self, dc_decl_base: Type[MappedAsDataclass]
+ ):
+ # this is currently the error raised by dataclasses. We can instead
+ # do this validation ourselves, but overall I don't know that we
+ # can hit every validation and rule that's in dataclasses
+ with expect_raises_message(
+ ValueError, "cannot specify both default and default_factory"
+ ):
+
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(
+ default="d1", default_factory=lambda: "d2"
+ )
+
+ def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]):
+ class Person(dc_decl_base):
+ __tablename__ = "person"
+ person_id: Mapped[int] = mapped_column(
+ primary_key=True, init=False
+ )
+ name: Mapped[str]
+ type: Mapped[str] = mapped_column(init=False)
+
+ __mapper_args__ = {"polymorphic_on": type}
+
+ class Engineer(Person):
+ __tablename__ = "engineer"
+
+ person_id: Mapped[int] = mapped_column(
+ ForeignKey("person.person_id"), primary_key=True, init=False
+ )
+
+ status: Mapped[str] = mapped_column(String(30))
+ engineer_name: Mapped[str]
+ primary_language: Mapped[str]
+
+ e1 = Engineer("nm", "st", "en", "pl")
+ eq_(e1.name, "nm")
+ eq_(e1.status, "st")
+ eq_(e1.engineer_name, "en")
+ eq_(e1.primary_language, "pl")
+
+ def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]):
+ """We will be telling users "this is a dataclass that is also
+ mapped". Therefore, they will want *any* kind of attribute to do what
+ it would normally do in a dataclass, including normal types without any
+ field and explicit use of dataclasses.field(). additionally, we'd like
+ ``Mapped`` to mean "persist this attribute". So the absence of
+ ``Mapped`` should also mean something too.
+
+ """
+
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ ctrl_one: str
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+ some_field: int = dataclasses.field(default=5)
+
+ some_none_field: Optional[str] = None
+
+ a1 = A("ctrlone", "datafield")
+ eq_(a1.some_field, 5)
+ eq_(a1.some_none_field, None)
+
+ # only Mapped[] is mapped
+ self.assert_compile(select(A), "SELECT a.id, a.data FROM a")
+ eq_(
+ pyinspect.getfullargspec(A.__init__),
+ pyinspect.FullArgSpec(
+ args=[
+ "self",
+ "ctrl_one",
+ "data",
+ "some_field",
+ "some_none_field",
+ ],
+ varargs=None,
+ varkw=None,
+ defaults=(5, None),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]):
+ class Person(decl_base):
+ __tablename__ = "person"
+ person_id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str]
+ type: Mapped[str] = mapped_column()
+
+ __mapper_args__ = {"polymorphic_on": type}
+
+ class Engineer(MappedAsDataclass, Person):
+ __tablename__ = "engineer"
+
+ person_id: Mapped[int] = mapped_column(
+ ForeignKey("person.person_id"), primary_key=True, init=False
+ )
+
+ status: Mapped[str] = mapped_column(String(30))
+ engineer_name: Mapped[str]
+ primary_language: Mapped[str]
+
+ e1 = Engineer("st", "en", "pl")
+ eq_(e1.status, "st")
+ eq_(e1.engineer_name, "en")
+ eq_(e1.primary_language, "pl")
+
+ eq_(
+ pyinspect.getfullargspec(Person.__init__),
+ # the boring **kw __init__
+ pyinspect.FullArgSpec(
+ args=["self"],
+ varargs=None,
+ varkw="kwargs",
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+ eq_(
+ pyinspect.getfullargspec(Engineer.__init__),
+ # the exciting dataclasses __init__
+ pyinspect.FullArgSpec(
+ args=["self", "status", "engineer_name", "primary_language"],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={},
+ ),
+ )
+
+
+class RelationshipDefaultFactoryTest(fixtures.TestBase):
+ def test_list(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=lambda: [B(data="hi")]
+ )
+
+ class B(dc_decl_base):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ a1 = A()
+ eq_(a1.bs[0].data, "hi")
+
+ def test_set(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[Set["B"]] = relationship( # noqa: F821
+ default_factory=lambda: {B(data="hi")}
+ )
+
+ class B(dc_decl_base, unsafe_hash=True):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ a1 = A()
+ eq_(a1.bs.pop().data, "hi")
+
+ def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]):
+ class A(dc_decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ bs: Mapped[Set["B"]] = relationship( # noqa: F821
+ default_factory=lambda: [B(data="hi")]
+ )
+
+ class B(dc_decl_base, unsafe_hash=True):
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+
+ # old school collection mismatch error FTW
+ with expect_raises_message(
+ TypeError, "Incompatible collection type: list is not set-like"
+ ):
+ A()
+
+ def test_replace_operation_works_w_history_etc(
+ self, registry: _RegistryType
+ ):
+ @registry.mapped_as_dataclass
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ bs: Mapped[List["B"]] = relationship( # noqa: F821
+ default_factory=list
+ )
+
+ @registry.mapped_as_dataclass
+ class B:
+ __tablename__ = "b"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ a_id = mapped_column(ForeignKey("a.id"), init=False)
+ data: Mapped[str]
+ x: Mapped[Optional[int]] = mapped_column(default=None)
+
+ registry.metadata.create_all(testing.db)
+
+ with Session(testing.db) as sess:
+ a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")])
+ sess.add(a1)
+ sess.commit()
+
+ a2 = dataclasses.replace(a1, x=12, bs=[B("b4")])
+
+ assert a1 in sess
+ assert not sess.is_modified(a1, include_collections=True)
+ assert a2 not in sess
+ eq_(inspect(a2).attrs.x.history, ([12], (), ()))
+ sess.add(a2)
+ sess.commit()
+
+ eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12])
+ eq_(
+ sess.scalars(select(B.data).order_by(B.id)).all(),
+ ["b1", "b2", "b3", "b4"],
+ )
+
+ def test_post_init(self, registry: _RegistryType):
+ @registry.mapped_as_dataclass
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str] = mapped_column(init=False)
+
+ def __post_init__(self):
+ self.data = "some data"
+
+ a1 = A()
+ eq_(a1.data, "some data")
+
+ def test_no_field_args_w_new_style(self, registry: _RegistryType):
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "SQLAlchemy mapped dataclasses can't consume mapping information",
+ ):
+
+ @registry.mapped_as_dataclass()
+ class A:
+ __tablename__ = "a"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+
+ def test_no_field_args_w_new_style_two(self, registry: _RegistryType):
+ @dataclasses.dataclass
+ class Base:
+ pass
+
+ with expect_raises_message(
+ exc.InvalidRequestError,
+ "SQLAlchemy mapped dataclasses can't consume mapping information",
+ ):
+
+ @registry.mapped_as_dataclass()
+ class A(Base):
+ __tablename__ = "a"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+
+
+class DataclassArgsTest(fixtures.TestBase):
+ dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash")
+
+ @testing.fixture(params=dc_arg_names)
+ def dc_argument_fixture(self, request: Any, registry: _RegistryType):
+ name = request.param
+
+ args = {n: n == name for n in self.dc_arg_names}
+ if args["order"]:
+ args["eq"] = True
+ yield args
+
+ @testing.fixture(
+ params=["mapped_column", "synonym", "deferred", "column_property"]
+ )
+ def mapped_expr_constructor(self, request):
+ name = request.param
+
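+ # each construct is expected to accept the dataclass-level
+ # default / init arguments in the same way as mapped_column()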
+ if name == "mapped_column":
+ yield mapped_column(default=7, init=True)
+ elif name == "synonym":
+ yield synonym("some_int", default=7, init=True)
+ elif name == "deferred":
+ yield deferred(Column(Integer), default=7, init=True)
+ elif name == "column_property":
+ yield column_property(Column(Integer), default=7, init=True)
+
+ def test_attrs_rejected_if_not_a_dc(
+ self, mapped_expr_constructor, decl_base: Type[DeclarativeBase]
+ ):
+ if isinstance(mapped_expr_constructor, MappedColumn):
+ unwanted_args = "'init'"
+ else:
+ unwanted_args = "'default', 'init'"
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Attribute 'x' on class .*A.* includes dataclasses "
+ r"argument\(s\): "
+ rf"{unwanted_args} but class does not specify SQLAlchemy native "
+ "dataclass configuration",
+ ):
+
+ class A(decl_base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+
+ x: Mapped[int] = mapped_expr_constructor
+
+ def _assert_cls(self, cls, dc_arguments):
+
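+ # dispatch to _assert_<flag> / _assert_not_<flag> below based
+ # on which dataclass argument the fixture enabled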
+ if dc_arguments["init"]:
+
+ def create(data, x):
+ return cls(data, x)
+
+ else:
+
+ def create(data, x):
+ a1 = cls()
+ a1.data = data
+ a1.x = x
+ return a1
+
+ for n in self.dc_arg_names:
+ if dc_arguments[n]:
+ getattr(self, f"_assert_{n}")(cls, create, dc_arguments)
+ else:
+ getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments)
+
+ if dc_arguments["init"]:
+ a1 = cls("some data")
+ eq_(a1.x, 7)
+
+ a1 = create("some data", 15)
+ some_int = a1.some_int
+ eq_(
+ dataclasses.asdict(a1),
+ {"data": "some data", "id": None, "some_int": some_int, "x": 15},
+ )
+ eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15))
+
+ def _assert_unsafe_hash(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ hash(a1)
+
+ def _assert_not_unsafe_hash(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+
+ if dc_arguments["eq"]:
+ with expect_raises(TypeError):
+ hash(a1)
+ else:
+ hash(a1)
+
+ def _assert_eq(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ a2 = create("d2", 10)
+ a3 = create("d1", 5)
+
+ eq_(a1, a3)
+ ne_(a1, a2)
+
+ def _assert_not_eq(self, cls, create, dc_arguments):
+ a1 = create("d1", 5)
+ a2 = create("d2", 10)
+ a3 = create("d1", 5)
+
+ eq_(a1, a1)
+ ne_(a1, a3)
+ ne_(a1, a2)
+
+ def _assert_order(self, cls, create, dc_arguments):
+ is_false(create("g", 10) < create("b", 7))
+
+ is_true(create("g", 10) > create("b", 7))
+
+ is_false(create("g", 10) <= create("b", 7))
+
+ is_true(create("g", 10) >= create("b", 7))
+
+ eq_(
+ list(sorted([create("g", 10), create("g", 5), create("b", 7)])),
+ [
+ create("b", 7),
+ create("g", 5),
+ create("g", 10),
+ ],
+ )
+
+ def _assert_not_order(self, cls, create, dc_arguments):
+ with expect_raises(TypeError):
+ create("g", 10) < create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) > create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) <= create("b", 7)
+
+ with expect_raises(TypeError):
+ create("g", 10) >= create("b", 7)
+
+ def _assert_repr(self, cls, create, dc_arguments):
+ a1 = create("some data", 12)
+ eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)")
+
+ def _assert_not_repr(self, cls, create, dc_arguments):
+ a1 = create("some data", 12)
+ eq_regex(repr(a1), r"<.*A object at 0x.*>")
+
+ def _assert_init(self, cls, create, dc_arguments):
+ a1 = cls("some data", 5)
+
+ eq_(a1.data, "some data")
+ eq_(a1.x, 5)
+
+ a2 = cls(data="some data", x=5)
+ eq_(a2.data, "some data")
+ eq_(a2.x, 5)
+
+ a3 = cls(data="some data")
+ eq_(a3.data, "some data")
+ eq_(a3.x, 7)
+
+ def _assert_not_init(self, cls, create, dc_arguments):
+
+ with expect_raises(TypeError):
+ cls("Some data", 5)
+
+ # real "dataclasses" processing runs on the class, so with
+ # init=False no __init__ is generated and SQLAlchemy's default
+ # constructor is applied instead.
+ a1 = cls(data="some data")
+ eq_(a1.data, "some data")
+ eq_(a1.x, None)
+
+ a1 = cls()
+ eq_(a1.data, None)
+
+ # with no generated constructor, x simply defaults to None
+ eq_(a1.x, None)
+
+ def test_dc_arguments_decorator(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ registry: _RegistryType,
+ ):
+ @registry.mapped_as_dataclass(**dc_argument_fixture)
+ class A:
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self._assert_cls(A, dc_argument_fixture)
+
+ def test_dc_arguments_base(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ registry: _RegistryType,
+ ):
+ reg = registry
+
+ class Base(MappedAsDataclass, DeclarativeBase, **dc_argument_fixture):
+ registry = reg
+
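+ # dataclass arguments given on the MappedAsDataclass base are
+ # inherited by mapped subclasses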
+ class A(Base):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self._assert_cls(A, dc_argument_fixture)
+
+ def test_dc_arguments_perclass(
+ self,
+ dc_argument_fixture,
+ mapped_expr_constructor,
+ decl_base: Type[DeclarativeBase],
+ ):
+ class A(MappedAsDataclass, decl_base, **dc_argument_fixture):
+ __tablename__ = "a"
+
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ data: Mapped[str]
+
+ some_int: Mapped[int] = mapped_column(init=False, repr=False)
+
+ x: Mapped[Optional[int]] = mapped_expr_constructor
+
+ self._assert_cls(A, dc_argument_fixture)
+
+
+class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+ __dialect__ = "default"
+
+ def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+ @dataclasses.dataclass
+ class Point:
+ x: int
+ y: int
+
+ class Edge(dc_decl_base):
+ __tablename__ = "edge"
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+ graph_id: Mapped[int] = mapped_column(
+ ForeignKey("graph.id"), init=False
+ )
+
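+ # composite() accepts the dataclass-level "default" argument
+ # in the same way as mapped_column()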
+ start: Mapped[Point] = composite(
+ Point, mapped_column("x1"), mapped_column("y1"), default=None
+ )
+
+ end: Mapped[Point] = composite(
+ Point, mapped_column("x2"), mapped_column("y2"), default=None
+ )
+
+ class Graph(dc_decl_base):
+ __tablename__ = "graph"
+ id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+ edges: Mapped[List[Edge]] = relationship()
+
+ Point.__qualname__ = "mymodel.Point"
+ Edge.__qualname__ = "mymodel.Edge"
+ Graph.__qualname__ = "mymodel.Graph"
+ g = Graph(
+ edges=[
+ Edge(start=Point(1, 2), end=Point(3, 4)),
+ Edge(start=Point(7, 8), end=Point(5, 6)),
+ ]
+ )
+ eq_(
+ repr(g),
+ "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, "
+ "graph_id=None, start=mymodel.Point(x=1, y=2), "
+ "end=mymodel.Point(x=3, y=4)), "
+ "mymodel.Edge(id=None, graph_id=None, "
+ "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])",
+ )
+
+ def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(dc_decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(
+ primary_key=True, init=False, repr=False
+ )
+ name: Mapped[str] = mapped_column()
+
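+ # mapped_column() objects with no name take their column names
+ # from the corresponding dataclass attributes; "zip" is stated
+ # explicitly since the attribute is named zip_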
+ address: Mapped[Address] = composite(
+ Address,
+ mapped_column(),
+ mapped_column(),
+ mapped_column("zip"),
+ default=None,
+ )
+
+ Address.__qualname__ = "mymodule.Address"
+ User.__qualname__ = "mymodule.User"
+ u = User(
+ name="user 1",
+ address=Address("123 anywhere street", "NY", "12345"),
+ )
+ u2 = User("u2")
+ eq_(
+ repr(u),
+ "mymodule.User(name='user 1', "
+ "address=mymodule.Address(street='123 anywhere street', "
+ "state='NY', zip_='12345'))",
+ )
+ eq_(repr(u2), "mymodule.User(name='u2', address=None)")
is_true(User.__table__.c.data.nullable)
assert isinstance(User.__table__.c.created_at.type, DateTime)
+ def test_column_default(self, decl_base):
+ class MyClass(decl_base):
+ __tablename__ = "mytable"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ data: Mapped[str] = mapped_column(default="some default")
+
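+ # a Column-level default does not populate the attribute at
+ # construction time; it takes effect at INSERT time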
+ mc = MyClass()
+ assert "data" not in mc.__dict__
+
+ eq_(MyClass.__table__.c.data.default.arg, "some default")
+
def test_anno_w_fixed_table(self, decl_base):
users = Table(
"users",
with expect_raises_message(
ArgumentError,
r"Type annotation for \"User.address\" should use the syntax "
- r"\"Mapped\['Address'\]\" or \"MappedColumn\['Address'\]\"",
+ r"\"Mapped\['Address'\]\"",
):
class User(decl_base):
# round trip!
eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
+ def test_cls_annotated_no_mapped_cols_setup(self, decl_base):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column()
+
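+ # with no arguments, composite() derives the composite class and
+ # its columns from the Mapped[Address] annotation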
+ address: Mapped[Address] = composite()
+
+ decl_base.metadata.create_all(testing.db)
+
+ with fixture_session() as sess:
+ sess.add(
+ User(
+ name="user 1",
+ address=Address("123 anywhere street", "NY", "12345"),
+ )
+ )
+ sess.commit()
+
+ with fixture_session() as sess:
+ u1 = sess.scalar(select(User))
+
+ # round trip!
+ eq_(u1.address, Address("123 anywhere street", "NY", "12345"))
+
def test_one_col_setup(self, decl_base):
@dataclasses.dataclass
class Address: