]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for classical mapping of dataclasses, Fixes: #5027 5516/head
authorVáclav Klusák <vaclav.klusak@maptiler.com>
Sun, 16 Aug 2020 14:40:42 +0000 (16:40 +0200)
committerVáclav Klusák <vaclav.klusak@klokantech.com>
Mon, 17 Aug 2020 15:44:28 +0000 (17:44 +0200)
- Ignore simple default value class attributes.
- Add tests for classical mapping of dataclasses.
- Document limitations and show example.

doc/build/orm/mapping_styles.rst
lib/sqlalchemy/orm/mapper.py
test/orm/test_dataclasses_py3k.py [new file with mode: 0644]

index f76e4521161b899517fed49be3ea9187652d1c10..c5dae19ab28daf1831e4883ca592c08e770a043b 100644 (file)
@@ -120,6 +120,86 @@ user-defined class, linked together with a :func:`.mapper`.  When we talk about
 "the behavior of :func:`.mapper`", this includes when using the Declarative system
 as well - it's still used, just behind the scenes.
 
+Mapping dataclasses and attrs
+-----------------------------
+
+The dataclasses_ module, added in Python 3.7, provides a ``dataclass`` class
+decorator to automatically generate boilerplate definitions of ``__init__()``,
+``__eq__()``, ``__repr()__``, etc. methods. Another very popular library that does
+the same, and much more, is attrs_. Classes defined using either of these can
+be mapped with the following caveats.
+
+.. note::
+
+    * Only classical mapping is possible, not Declarative. Classes inheriting from
+      Declarative base would get processed by SQLAlchemy before being handed to the
+      ``dataclass`` or ``attrs`` decorator, and this would interfere with its function.
+
+    * The ``dataclass`` decorator adds class attributes corresponding to simple default values.
+      This is done mostly as documentation, these attributes are not necessary for the function
+      of any of the generated methods. Mapping replaces these class attributes with property
+      descriptors.
+
+    * Mapping of frozen ``dataclass`` and ``attrs`` classes is not possible, because the
+      machinery used to enforce immutability interferes with loading.
+
+Example::
+
+    from __future__ import annotations
+    from dataclasses import dataclass, field
+    from typing import List
+
+    from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
+    from sqlalchemy.orm import mapper, relationship
+
+    @dataclass
+    class User:
+        id: int = field(init=False)
+        name: str = None
+        fullname: str = None
+        nickname: str = None
+        addresses: List[Address] = field(default_factory=list)
+
+    @dataclass
+    class Address:
+        id: int = field(init=False)
+        user_id: int = field(init=False)
+        email_address: str = None
+
+    metadata = MetaData()
+
+    user = Table(
+        'user',
+        metadata,
+        Column('id', Integer, primary_key=True),
+        Column('name', String(50)),
+        Column('fullname', String(50)),
+        Column('nickname', String(12)),
+    )
+
+    address = Table(
+        'address',
+        metadata,
+        Column('id', Integer, primary_key=True),
+        Column('user_id', Integer, ForeignKey('user.id')),
+        Column('email_address', String(50)),
+    )
+
+    mapper(User, user, properties={
+        'addresses': relationship(Address, backref='user', order_by=address.c.id),
+    })
+
+    mapper(Address, address)
+
+Note that ``User.id``, ``Address.id``, and ``Address.user_id`` are defined as ``field(init=False)``.
+This means that parameters for these won't be added to ``__init__()`` methods, but
+:class:`.Session` will still be able to set them after getting their values during flush
+from autoincrement or other default value generator. You can also give them a
+``None`` default value instead if you want to be able to specify their values in the constructor.
+
+.. _dataclasses: https://docs.python.org/3/library/dataclasses.html
+.. _attrs: https://www.attrs.org/en/stable/
+
 Runtime Introspection of Mappings, Objects
 ==========================================
 
index 446f6790ed267a67d9d18ba49252e31a8029e256..b21cd9076aef608686d52e7c9a58bee0bd6b3832 100644 (file)
@@ -22,6 +22,12 @@ import sys
 import types
 import weakref
 
+try:
+    import dataclasses
+except ImportError:
+    # The dataclasses module was added in Python 3.7
+    dataclasses = None
+
 from . import attributes
 from . import exc as orm_exc
 from . import instrumentation
@@ -2645,6 +2651,20 @@ class Mapper(
         else:
             return True
 
+    @HasMemoized.memoized_attribute
+    def _is_dataclass(self):
+        if dataclasses is not None:
+            return dataclasses.is_dataclass(self.class_)
+        else:
+            return False
+
+    @HasMemoized.memoized_attribute
+    def _dataclass_fields(self):
+        return set(field.name for field in dataclasses.fields(self.class_))
+
+    def _is_dataclass_field(self, assigned_name):
+        return self._is_dataclass and assigned_name in self._dataclass_fields
+
     def _should_exclude(self, name, assigned_name, local, column):
         """determine whether a particular property should be implicitly
         present on the class.
@@ -2656,16 +2676,21 @@ class Mapper(
 
         # check for class-bound attributes and/or descriptors,
         # either local or from an inherited class
+        # ignore dataclass field default values
         if local:
             if self.class_.__dict__.get(
                 assigned_name, None
             ) is not None and self._is_userland_descriptor(
                 self.class_.__dict__[assigned_name]
-            ):
+            ) and not self._is_dataclass_field(assigned_name):
                 return True
         else:
             attr = self.class_manager._get_class_attr_mro(assigned_name, None)
-            if attr is not None and self._is_userland_descriptor(attr):
+            if (
+                attr is not None
+                and self._is_userland_descriptor(attr)
+                and not self._is_dataclass_field(assigned_name)
+            ):
                 return True
 
         if (
diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py
new file mode 100644 (file)
index 0000000..6ffe745
--- /dev/null
@@ -0,0 +1,232 @@
+from typing import List, Optional
+
+from sqlalchemy import Boolean
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy import testing
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+import pytest
+
+
+# The dataclasses module was added in Python 3.7
+dataclasses = pytest.importorskip("dataclasses")
+
+
+class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+    run_setup_classes = "each"
+    run_setup_mappers = "each"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "accounts",
+            metadata,
+            Column("account_id", Integer, primary_key=True),
+            Column("widget_count", Integer, nullable=False),
+        )
+        Table(
+            "widgets",
+            metadata,
+            Column("widget_id", Integer, primary_key=True),
+            Column(
+                "account_id",
+                Integer,
+                ForeignKey("accounts.account_id"),
+                nullable=False,
+            ),
+            Column("type", String, nullable=False),
+            Column("name", String, nullable=False),
+            Column("magic", Boolean),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        @dataclasses.dataclass
+        class Widget:
+            name: Optional[str] = None
+
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            magic: bool = False
+
+        @dataclasses.dataclass
+        class Account:
+            account_id: int
+            widgets: List[Widget] = dataclasses.field(default_factory=list)
+            widget_count: int = dataclasses.field(init=False)
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        cls.classes.Account = Account
+        cls.classes.Widget = Widget
+        cls.classes.SpecialWidget = SpecialWidget
+
+    @classmethod
+    def setup_mappers(cls):
+        accounts = cls.tables.accounts
+        widgets = cls.tables.widgets
+
+        Account = cls.classes.Account
+        Widget = cls.classes.Widget
+        SpecialWidget = cls.classes.SpecialWidget
+
+        mapper(
+            Widget,
+            widgets,
+            polymorphic_on=widgets.c.type,
+            polymorphic_identity="normal",
+        )
+        mapper(
+            SpecialWidget,
+            widgets,
+            inherits=Widget,
+            polymorphic_identity="special",
+        )
+        mapper(Account, accounts, properties={
+            "widgets": relationship(Widget),
+        })
+
+    def check_account_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        account_id, widgets, widget_count = dataclasses.fields(obj)
+        assert account_id.name == "account_id"
+        assert widget_count.name == "widget_count"
+        assert widgets.name == "widgets"
+
+    def check_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        name, = dataclasses.fields(obj)
+        assert name.name == "name"
+
+    def check_special_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        name, magic = dataclasses.fields(obj)
+        assert name.name == "name"
+        assert magic.name == "magic"
+
+    def data_fixture(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        return Account(
+            account_id=42,
+            widgets=[
+                Widget("Foo"),
+                SpecialWidget("Bar", magic=True),
+            ],
+        )
+
+    def check_data_fixture(self, account):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        self.check_account_dataclass(account)
+        assert account.account_id == 42
+        assert account.widget_count == len(account.widgets) == 2
+
+        foo, bar = account.widgets
+
+        self.check_widget_dataclass(foo)
+        assert isinstance(foo, Widget)
+        assert foo.name == "Foo"
+
+        self.check_special_widget_dataclass(bar)
+        assert isinstance(bar, SpecialWidget)
+        assert bar.name == "Bar"
+        assert bar.magic == True
+
+    def test_classes_are_still_dataclasses(self):
+        self.check_account_dataclass(self.classes.Account)
+        self.check_widget_dataclass(self.classes.Widget)
+        self.check_special_widget_dataclass(self.classes.SpecialWidget)
+
+    def test_construction(self):
+        SpecialWidget = self.classes.SpecialWidget
+
+        account = self.data_fixture()
+        self.check_data_fixture(account)
+
+        widget = SpecialWidget()
+        assert widget.name == None
+        assert widget.magic == False
+
+    def test_equality(self):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        assert Widget("Foo") == Widget("Foo")
+        assert Widget("Foo") != Widget("Bar")
+        assert Widget("Foo") != SpecialWidget("Foo")
+
+    def test_asdict_and_astuple(self):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        widget = Widget("Foo")
+        assert dataclasses.asdict(widget) == {"name": "Foo"}
+        assert dataclasses.astuple(widget) == ("Foo",)
+
+        widget = SpecialWidget("Bar", magic=True)
+        assert dataclasses.asdict(widget) == {
+            "name": "Bar", "magic": True
+        }
+        assert dataclasses.astuple(widget) == ("Bar", True)
+
+    def test_round_trip(self):
+        Account = self.classes.Account
+        account = self.data_fixture()
+
+        session = Session(testing.db)
+        session.add(account)
+        session.commit()
+        session.close()
+
+        a = session.query(Account).get(42)
+        self.check_data_fixture(a)
+
+    def test_appending_to_relationship(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        account = self.data_fixture()
+
+        session = Session(testing.db)
+        session.add(account)
+        session.commit()
+
+        account.add_widget(Widget("Xyzzy"))
+        session.commit()
+        session.close()
+
+        a = session.query(Account).get(42)
+        assert a.widget_count == len(a.widgets) == 3
+
+    def test_filtering_on_relationship(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        account = self.data_fixture()
+
+        session = Session(testing.db)
+        session.add(account)
+        session.commit()
+        session.close()
+
+        a = (
+            session.query(Account)
+            .join(Account.widgets)
+            .filter(Widget.name == "Foo")
+            .one()
+        )
+        self.check_data_fixture(a)