]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow Declarative to extract class attr from field
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Dec 2020 19:35:30 +0000 (14:35 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 19 Dec 2020 17:43:16 +0000 (12:43 -0500)
Added an alternate resolution scheme to Declarative that will extract the
SQLAlchemy column or mapped property from the "metadata" dictionary of a
dataclasses.Field object.  This allows full declarative mappings to be
combined with dataclass fields.

Fixes: #5745
Change-Id: I1165bc025246a4cb9fc099b1b7c46a6b0f799b23

doc/build/changelog/unreleased_14/5745.rst [new file with mode: 0644]
doc/build/orm/mapping_styles.rst
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/orm/test_dataclasses_py3k.py

diff --git a/doc/build/changelog/unreleased_14/5745.rst b/doc/build/changelog/unreleased_14/5745.rst
new file mode 100644 (file)
index 0000000..3078860
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: feature, orm, declarative
+    :tickets: 5745
+
+    Added an alternate resolution scheme to Declarative that will extract the
+    SQLAlchemy column or mapped property from the "metadata" dictionary of a
+    dataclasses.Field object.  This allows full declarative mappings to be
+    combined with dataclass fields.
+
+    .. seealso::
+
+        :ref:`orm_declarative_dataclasses_declarative_table`
\ No newline at end of file
index 1cd742b545c72e77a45fc67d048bc45f86f2f849..d260724d854834ff36d36700d1744331b8e72741 100644 (file)
@@ -31,9 +31,13 @@ The full suite of styles can be hierarchically organized as follows:
         * :ref:`orm_declarative_table`
         * :ref:`Imperative Table (a.k.a. "hybrid table") <orm_imperative_table_configuration>`
     * Using :meth:`_orm.registry.mapped` Declarative Decorator
-        * Declarative Table
-        * Imperative Table (Hybrid)
-            * :ref:`orm_declarative_dataclasses`
+        * :ref:`Declarative Table <orm_declarative_decorator>` - combine :meth:`_orm.registry.mapped`
+          with ``__tablename__``
+        * Imperative Table (Hybrid) - combine :meth:`_orm.registry.mapped` with ``__table__``
+        * :ref:`orm_declarative_dataclasses`
+            * :ref:`orm_declarative_dataclasses_imperative_table`
+            * :ref:`orm_declarative_dataclasses_declarative_table`
+            * :ref:`orm_declarative_attrs_imperative_table`
 * :ref:`Imperative (a.k.a. "classical" mapping) <orm_imperative_mapping>`
     * Using :meth:`_orm.registry.map_imperatively`
         * :ref:`orm_imperative_dataclasses`
@@ -198,13 +202,14 @@ ORM mapping process proceeds via the :meth:`_orm.registry.mapped` decorator
 or via the :meth:`_orm.registry.map_imperatively` method discussed in a
 later section.
 
-As the attributes set up for ``@dataclass`` or ``@attr.s`` are typically those
-which will be matched up to the :class:`_schema.Column` objects that are
-mapped, it is usually required that the
-:ref:`orm_imperative_table_configuration` style is used in order to configure
+Mapping with ``@dataclass`` or ``@attr.s`` may be used in a straightforward
+way with :ref:`orm_imperative_table_configuration` style, where the
 the :class:`_schema.Table`, which means that it is defined separately and
-associated with the class via the ``__table__``.
+associated with the class via the ``__table__``.   For dataclasses specifically,
+:ref:`orm_declarative_table` is also supported.
 
+.. versionadded:: 1.4.0b2 Added support for full declarative mapping when using
+   dataclasses.
 
 When attributes are defined using ``dataclasses``, the ``@dataclass``
 decorator consumes them but leaves them in place on the class.
@@ -223,7 +228,13 @@ mapping process takes over these attributes without any issue.
    than skipping them as is the default behavior for any class attribute
    that's not part of the mapping.
 
-An example of a mapping using ``@dataclass`` is as follows::
+.. _orm_declarative_dataclasses_imperative_table:
+
+Example One - Dataclasses with Imperative Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+An example of a mapping using ``@dataclass`` using
+:ref:`orm_imperative_table_configuration` is as follows::
 
     from __future__ import annotations
 
@@ -288,7 +299,82 @@ during flush from autoincrement or other default value generator.   To
 allow them to be specified in the constructor explicitly, they would instead
 be given a default value of ``None``.
 
-Similarly, a mapping using ``@attr.s``::
+For a :func:`_orm.relationship` to be declared separately, it needs to
+be specified directly within the :paramref:`_orm.mapper.properties`
+dictionary passed to the :func:`_orm.mapper`.   An alternative to this
+approach is in the next example.
+
+.. _orm_declarative_dataclasses_declarative_table:
+
+Example Two - Dataclasses with Declarative Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The fully declarative approach requires that :class:`_schema.Column` objects
+are declared as class attributes, which when using dataclasses would conflict
+with the dataclass-level attributes.  An approach to combine these together
+is to make use of the ``metadata`` attribute on the ``dataclass.field``
+object, where SQLAlchemy-specific mapping information may be supplied.
+Declarative supports extraction of these parameters when the class
+specifies the attribute ``__sa_dataclass_metadata_key__``.  This also
+provides a more succinct method of indicating the :func:`_orm.relationship`
+association::
+
+
+    from __future__ import annotations
+
+    from dataclasses import dataclass
+    from dataclasses import field
+    from typing import List
+
+    from sqlalchemy import Column
+    from sqlalchemy import ForeignKey
+    from sqlalchemy import Integer
+    from sqlalchemy import String
+    from sqlalchemy.orm import registry
+    from sqlalchemy.orm import relationship
+
+    mapper_registry = registry()
+
+
+    @mapper_registry.mapped
+    @dataclass
+    class User:
+        __tablename__ = "user"
+
+        __sa_dataclass_metadata_key__ = "sa"
+        id: int = field(
+            init=False, metadata={"sa": Column(Integer, primary_key=True)}
+        )
+        name: str = field(default=None, metadata={"sa": Column(String(50))})
+        fullname: str = field(default=None, metadata={"sa": Column(String(50))})
+        nickname: str = field(default=None, metadata={"sa": Column(String(12))})
+        addresses: List[Address] = field(
+            default_factory=list, metadata={"sa": relationship("Address")}
+        )
+
+
+    @mapper_registry.mapped
+    @dataclass
+    class Address:
+        __tablename__ = "address"
+        __sa_dataclass_metadata_key__ = "sa"
+        id: int = field(
+            init=False, metadata={"sa": Column(Integer, primary_key=True)}
+        )
+        user_id: int = field(
+            init=False, metadata={"sa": Column(ForeignKey("user.id"))}
+        )
+        email_address: str = field(
+            default=None, metadata={"sa": Column(String(50))}
+        )
+
+
+.. _orm_declarative_attrs_imperative_table:
+
+Example Three - attrs with Imperative Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+A mapping using ``@attr.s``, in conjunction with imperative table::
 
     import attr
 
index 8da326b0e6d8a8a3e6c3c81f72d56fc7d824e44e..353f44e43269f6cc01b996d04631e6099f80c269 100644 (file)
@@ -334,6 +334,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         tablename = None
 
         for base in cls.__mro__:
+
+            sa_dataclass_metadata_key = None
+
             class_mapped = (
                 base is not cls
                 and _declared_mapping_info(base) is not None
@@ -342,10 +345,25 @@ class _ClassScanMapperConfig(_MapperConfig):
                 )
             )
 
+            if sa_dataclass_metadata_key is None:
+                sa_dataclass_metadata_key = _get_immediate_cls_attr(
+                    base, "__sa_dataclass_metadata_key__", None
+                )
+
+            def attributes_for_class(cls):
+                for name, obj in vars(cls).items():
+                    yield name, obj
+                if sa_dataclass_metadata_key:
+                    for field in util.dataclass_fields(cls):
+                        if sa_dataclass_metadata_key in field.metadata:
+                            yield field.name, field.metadata[
+                                sa_dataclass_metadata_key
+                            ]
+
             if not class_mapped and base is not cls:
-                self._produce_column_copies(base)
+                self._produce_column_copies(attributes_for_class, base)
 
-            for name, obj in vars(base).items():
+            for name, obj in attributes_for_class(base):
                 if name == "__mapper_args__":
                     check_decl = _check_declared_props_nocascade(
                         obj, name, cls
@@ -452,6 +470,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # however, check for some more common mistakes
                     else:
                         self._warn_for_decl_attributes(base, name, obj)
+                elif name not in dict_ or dict_[name] is not obj:
+                    dict_[name] = obj
 
         if inherited_table_args and not tablename:
             table_args = None
@@ -469,12 +489,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                 % (key, cls)
             )
 
-    def _produce_column_copies(self, base):
+    def _produce_column_copies(self, attributes_for_class, base):
         cls = self.cls
         dict_ = self.dict_
         column_copies = self.column_copies
         # copy mixin columns to the mapped class
-        for name, obj in vars(base).items():
+        for name, obj in attributes_for_class(base):
             if isinstance(obj, Column):
                 if getattr(cls, name) is not obj:
                     # if column has been overridden
index e501838940bcfedcd6146db2d2fcc8e3be4985a0..e8f98d1506f6c6c6b20a0c09dd1cc26d1281a79d 100644 (file)
@@ -56,12 +56,6 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..util import HasMemoized
 
-try:
-    import dataclasses
-except ImportError:
-    # The dataclasses module was added in Python 3.7
-    dataclasses = None
-
 
 _mapper_registry = weakref.WeakKeyDictionary()
 _already_compiling = False
@@ -2645,10 +2639,7 @@ class Mapper(
 
     @HasMemoized.memoized_attribute
     def _dataclass_fields(self):
-        if dataclasses is None or not dataclasses.is_dataclass(self.class_):
-            return frozenset()
-
-        return {field.name for field in dataclasses.fields(self.class_)}
+        return [f.name for f in util.dataclass_fields(self.class_)]
 
     def _should_exclude(self, name, assigned_name, local, column):
         """determine whether a particular property should be implicitly
index f4363d03cef7d3c55a4fb10459481b3da623664f..2e3f687229723999332f2240359755e045658081 100644 (file)
@@ -57,6 +57,7 @@ from .compat import byte_buffer  # noqa
 from .compat import callable  # noqa
 from .compat import cmp  # noqa
 from .compat import cpython  # noqa
+from .compat import dataclass_fields  # noqa
 from .compat import decode_backslashreplace  # noqa
 from .compat import dottedgetter  # noqa
 from .compat import has_refcount_gc  # noqa
index e8c4880478db642664cb05f3040720f54c436eed..77c913640b516a616f54d0a2dec40a6082fdedf4 100644 (file)
@@ -421,6 +421,22 @@ else:
     import collections as collections_abc  # noqa
 
 
+if py37:
+    import dataclasses
+
+    def dataclass_fields(cls):
+        if dataclasses.is_dataclass(cls):
+            return dataclasses.fields(cls)
+        else:
+            return []
+
+
+else:
+
+    def dataclass_fields(cls):
+        return []
+
+
 def raise_from_cause(exception, exc_info=None):
     r"""legacy.  use raise\_()"""
 
index d3f9530724afa2600e3de786d2674bcba4dbe13f..a4b5e4c83d5f0410bec218285d734627c09fca57 100644 (file)
@@ -288,3 +288,99 @@ class PlainDeclarativeDataclassesTest(DataclassesTest):
     @classmethod
     def setup_mappers(cls):
         pass
+
+
+class FieldEmbeddedDeclarativeDataclassesTest(
+    fixtures.DeclarativeMappedTest, DataclassesTest
+):
+    __requires__ = ("dataclasses",)
+
+    @classmethod
+    def setup_classes(cls):
+        declarative = cls.DeclarativeBasic.registry.mapped
+
+        @declarative
+        @dataclasses.dataclass
+        class Widget:
+            __tablename__ = "widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            widget_id = Column(Integer, primary_key=True)
+            account_id = Column(
+                Integer,
+                ForeignKey("accounts.account_id"),
+                nullable=False,
+            )
+            type = Column(String(30), nullable=False)
+
+            name: Optional[str] = dataclasses.field(
+                default=None,
+                metadata={"sa": Column(String(30), nullable=False)},
+            )
+            __mapper_args__ = dict(
+                polymorphic_on="type",
+                polymorphic_identity="normal",
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            __sa_dataclass_metadata_key__ = "sa"
+
+            magic: bool = dataclasses.field(
+                default=False, metadata={"sa": Column(Boolean)}
+            )
+
+            __mapper_args__ = dict(
+                polymorphic_identity="special",
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class Account:
+            __tablename__ = "accounts"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            account_id: int = dataclasses.field(
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+            widgets: List[Widget] = dataclasses.field(
+                default_factory=list, metadata={"sa": relationship("Widget")}
+            )
+            widget_count: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column("widget_count", Integer, nullable=False)
+                },
+            )
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        cls.classes.Account = Account
+        cls.classes.Widget = Widget
+        cls.classes.SpecialWidget = SpecialWidget
+
+    @classmethod
+    def setup_mappers(cls):
+        pass
+
+    @classmethod
+    def define_tables(cls, metadata):
+        pass
+
+    def test_asdict_and_astuple(self):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        widget = Widget("Foo")
+        eq_(dataclasses.asdict(widget), {"name": "Foo"})
+        eq_(dataclasses.astuple(widget), ("Foo",))
+
+        widget = SpecialWidget("Bar", magic=True)
+        eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
+        eq_(dataclasses.astuple(widget), ("Bar", True))