From: Mike Bayer Date: Thu, 3 Dec 2020 19:35:30 +0000 (-0500) Subject: Allow Declarative to extract class attr from field X-Git-Tag: rel_1_4_0b2~83^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=9625adba3553803acd5488660d65c8e675a61fa6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow Declarative to extract class attr from field Added an alternate resolution scheme to Declarative that will extract the SQLAlchemy column or mapped property from the "metadata" dictionary of a dataclasses.Field object. This allows full declarative mappings to be combined with dataclass fields. Fixes: #5745 Change-Id: I1165bc025246a4cb9fc099b1b7c46a6b0f799b23 --- diff --git a/doc/build/changelog/unreleased_14/5745.rst b/doc/build/changelog/unreleased_14/5745.rst new file mode 100644 index 0000000000..30788604ce --- /dev/null +++ b/doc/build/changelog/unreleased_14/5745.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: feature, orm, declarative + :tickets: 5745 + + Added an alternate resolution scheme to Declarative that will extract the + SQLAlchemy column or mapped property from the "metadata" dictionary of a + dataclasses.Field object. This allows full declarative mappings to be + combined with dataclass fields. + + .. seealso:: + + :ref:`orm_declarative_dataclasses_declarative_table` \ No newline at end of file diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst index 1cd742b545..d260724d85 100644 --- a/doc/build/orm/mapping_styles.rst +++ b/doc/build/orm/mapping_styles.rst @@ -31,9 +31,13 @@ The full suite of styles can be hierarchically organized as follows: * :ref:`orm_declarative_table` * :ref:`Imperative Table (a.k.a. "hybrid table") ` * Using :meth:`_orm.registry.mapped` Declarative Decorator - * Declarative Table - * Imperative Table (Hybrid) - * :ref:`orm_declarative_dataclasses` + * :ref:`Declarative Table ` - combine :meth:`_orm.registry.mapped` + with ``__tablename__`` + * Imperative Table (Hybrid) - combine :meth:`_orm.registry.mapped` with ``__table__`` + * :ref:`orm_declarative_dataclasses` + * :ref:`orm_declarative_dataclasses_imperative_table` + * :ref:`orm_declarative_dataclasses_declarative_table` + * :ref:`orm_declarative_attrs_imperative_table` * :ref:`Imperative (a.k.a. "classical" mapping) ` * Using :meth:`_orm.registry.map_imperatively` * :ref:`orm_imperative_dataclasses` @@ -198,13 +202,14 @@ ORM mapping process proceeds via the :meth:`_orm.registry.mapped` decorator or via the :meth:`_orm.registry.map_imperatively` method discussed in a later section. -As the attributes set up for ``@dataclass`` or ``@attr.s`` are typically those -which will be matched up to the :class:`_schema.Column` objects that are -mapped, it is usually required that the -:ref:`orm_imperative_table_configuration` style is used in order to configure +Mapping with ``@dataclass`` or ``@attr.s`` may be used in a straightforward +way with :ref:`orm_imperative_table_configuration` style, where the the :class:`_schema.Table`, which means that it is defined separately and -associated with the class via the ``__table__``. +associated with the class via the ``__table__``. For dataclasses specifically, +:ref:`orm_declarative_table` is also supported. +.. versionadded:: 1.4.0b2 Added support for full declarative mapping when using + dataclasses. When attributes are defined using ``dataclasses``, the ``@dataclass`` decorator consumes them but leaves them in place on the class. @@ -223,7 +228,13 @@ mapping process takes over these attributes without any issue. than skipping them as is the default behavior for any class attribute that's not part of the mapping. -An example of a mapping using ``@dataclass`` is as follows:: +.. _orm_declarative_dataclasses_imperative_table: + +Example One - Dataclasses with Imperative Table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An example of a mapping using ``@dataclass`` using +:ref:`orm_imperative_table_configuration` is as follows:: from __future__ import annotations @@ -288,7 +299,82 @@ during flush from autoincrement or other default value generator. To allow them to be specified in the constructor explicitly, they would instead be given a default value of ``None``. -Similarly, a mapping using ``@attr.s``:: +For a :func:`_orm.relationship` to be declared separately, it needs to +be specified directly within the :paramref:`_orm.mapper.properties` +dictionary passed to the :func:`_orm.mapper`. An alternative to this +approach is in the next example. + +.. _orm_declarative_dataclasses_declarative_table: + +Example Two - Dataclasses with Declarative Table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The fully declarative approach requires that :class:`_schema.Column` objects +are declared as class attributes, which when using dataclasses would conflict +with the dataclass-level attributes. An approach to combine these together +is to make use of the ``metadata`` attribute on the ``dataclass.field`` +object, where SQLAlchemy-specific mapping information may be supplied. +Declarative supports extraction of these parameters when the class +specifies the attribute ``__sa_dataclass_metadata_key__``. This also +provides a more succinct method of indicating the :func:`_orm.relationship` +association:: + + + from __future__ import annotations + + from dataclasses import dataclass + from dataclasses import field + from typing import List + + from sqlalchemy import Column + from sqlalchemy import ForeignKey + from sqlalchemy import Integer + from sqlalchemy import String + from sqlalchemy.orm import registry + from sqlalchemy.orm import relationship + + mapper_registry = registry() + + + @mapper_registry.mapped + @dataclass + class User: + __tablename__ = "user" + + __sa_dataclass_metadata_key__ = "sa" + id: int = field( + init=False, metadata={"sa": Column(Integer, primary_key=True)} + ) + name: str = field(default=None, metadata={"sa": Column(String(50))}) + fullname: str = field(default=None, metadata={"sa": Column(String(50))}) + nickname: str = field(default=None, metadata={"sa": Column(String(12))}) + addresses: List[Address] = field( + default_factory=list, metadata={"sa": relationship("Address")} + ) + + + @mapper_registry.mapped + @dataclass + class Address: + __tablename__ = "address" + __sa_dataclass_metadata_key__ = "sa" + id: int = field( + init=False, metadata={"sa": Column(Integer, primary_key=True)} + ) + user_id: int = field( + init=False, metadata={"sa": Column(ForeignKey("user.id"))} + ) + email_address: str = field( + default=None, metadata={"sa": Column(String(50))} + ) + + +.. _orm_declarative_attrs_imperative_table: + +Example Three - attrs with Imperative Table +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A mapping using ``@attr.s``, in conjunction with imperative table:: import attr diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 8da326b0e6..353f44e432 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -334,6 +334,9 @@ class _ClassScanMapperConfig(_MapperConfig): tablename = None for base in cls.__mro__: + + sa_dataclass_metadata_key = None + class_mapped = ( base is not cls and _declared_mapping_info(base) is not None @@ -342,10 +345,25 @@ class _ClassScanMapperConfig(_MapperConfig): ) ) + if sa_dataclass_metadata_key is None: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + base, "__sa_dataclass_metadata_key__", None + ) + + def attributes_for_class(cls): + for name, obj in vars(cls).items(): + yield name, obj + if sa_dataclass_metadata_key: + for field in util.dataclass_fields(cls): + if sa_dataclass_metadata_key in field.metadata: + yield field.name, field.metadata[ + sa_dataclass_metadata_key + ] + if not class_mapped and base is not cls: - self._produce_column_copies(base) + self._produce_column_copies(attributes_for_class, base) - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -452,6 +470,8 @@ class _ClassScanMapperConfig(_MapperConfig): # however, check for some more common mistakes else: self._warn_for_decl_attributes(base, name, obj) + elif name not in dict_ or dict_[name] is not obj: + dict_[name] = obj if inherited_table_args and not tablename: table_args = None @@ -469,12 +489,12 @@ class _ClassScanMapperConfig(_MapperConfig): % (key, cls) ) - def _produce_column_copies(self, base): + def _produce_column_copies(self, attributes_for_class, base): cls = self.cls dict_ = self.dict_ column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj in vars(base).items(): + for name, obj in attributes_for_class(base): if isinstance(obj, Column): if getattr(cls, name) is not obj: # if column has been overridden diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e501838940..e8f98d1506 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -56,12 +56,6 @@ from ..sql import util as sql_util from ..sql import visitors from ..util import HasMemoized -try: - import dataclasses -except ImportError: - # The dataclasses module was added in Python 3.7 - dataclasses = None - _mapper_registry = weakref.WeakKeyDictionary() _already_compiling = False @@ -2645,10 +2639,7 @@ class Mapper( @HasMemoized.memoized_attribute def _dataclass_fields(self): - if dataclasses is None or not dataclasses.is_dataclass(self.class_): - return frozenset() - - return {field.name for field in dataclasses.fields(self.class_)} + return [f.name for f in util.dataclass_fields(self.class_)] def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index f4363d03ce..2e3f687229 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -57,6 +57,7 @@ from .compat import byte_buffer # noqa from .compat import callable # noqa from .compat import cmp # noqa from .compat import cpython # noqa +from .compat import dataclass_fields # noqa from .compat import decode_backslashreplace # noqa from .compat import dottedgetter # noqa from .compat import has_refcount_gc # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index e8c4880478..77c913640b 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -421,6 +421,22 @@ else: import collections as collections_abc # noqa +if py37: + import dataclasses + + def dataclass_fields(cls): + if dataclasses.is_dataclass(cls): + return dataclasses.fields(cls) + else: + return [] + + +else: + + def dataclass_fields(cls): + return [] + + def raise_from_cause(exception, exc_info=None): r"""legacy. use raise\_()""" diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index d3f9530724..a4b5e4c83d 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -288,3 +288,99 @@ class PlainDeclarativeDataclassesTest(DataclassesTest): @classmethod def setup_mappers(cls): pass + + +class FieldEmbeddedDeclarativeDataclassesTest( + fixtures.DeclarativeMappedTest, DataclassesTest +): + __requires__ = ("dataclasses",) + + @classmethod + def setup_classes(cls): + declarative = cls.DeclarativeBasic.registry.mapped + + @declarative + @dataclasses.dataclass + class Widget: + __tablename__ = "widgets" + __sa_dataclass_metadata_key__ = "sa" + + widget_id = Column(Integer, primary_key=True) + account_id = Column( + Integer, + ForeignKey("accounts.account_id"), + nullable=False, + ) + type = Column(String(30), nullable=False) + + name: Optional[str] = dataclasses.field( + default=None, + metadata={"sa": Column(String(30), nullable=False)}, + ) + __mapper_args__ = dict( + polymorphic_on="type", + polymorphic_identity="normal", + ) + + @declarative + @dataclasses.dataclass + class SpecialWidget(Widget): + __sa_dataclass_metadata_key__ = "sa" + + magic: bool = dataclasses.field( + default=False, metadata={"sa": Column(Boolean)} + ) + + __mapper_args__ = dict( + polymorphic_identity="special", + ) + + @declarative + @dataclasses.dataclass + class Account: + __tablename__ = "accounts" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + metadata={"sa": Column(Integer, primary_key=True)}, + ) + widgets: List[Widget] = dataclasses.field( + default_factory=list, metadata={"sa": relationship("Widget")} + ) + widget_count: int = dataclasses.field( + init=False, + metadata={ + "sa": Column("widget_count", Integer, nullable=False) + }, + ) + + def __post_init__(self): + self.widget_count = len(self.widgets) + + def add_widget(self, widget: Widget): + self.widgets.append(widget) + self.widget_count += 1 + + cls.classes.Account = Account + cls.classes.Widget = Widget + cls.classes.SpecialWidget = SpecialWidget + + @classmethod + def setup_mappers(cls): + pass + + @classmethod + def define_tables(cls, metadata): + pass + + def test_asdict_and_astuple(self): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + widget = Widget("Foo") + eq_(dataclasses.asdict(widget), {"name": "Foo"}) + eq_(dataclasses.astuple(widget), ("Foo",)) + + widget = SpecialWidget("Bar", magic=True) + eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True}) + eq_(dataclasses.astuple(widget), ("Bar", True))