]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add support for classical mapping of dataclasses
authorVáclav Klusák <vaclav.klusak@maptiler.com>
Mon, 17 Aug 2020 15:58:56 +0000 (11:58 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 1 Sep 2020 14:59:07 +0000 (10:59 -0400)
Added support for direct mapping of Python classes that are defined using
the Python ``dataclasses`` decorator.    See the section
:ref:`mapping_dataclasses` for background.  Pull request courtesy Václav
Klusák.

Fixes: #5027
Closes: #5516
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5516
Pull-request-sha: bb48c63d1561ca48c954ad9f84a3eb2646571115

Change-Id: Ie33db2aae4adeeb5d99633fe926b9c30bab0b885

doc/build/changelog/unreleased_14/5027.rst [new file with mode: 0644]
doc/build/orm/mapping_styles.rst
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/testing/requirements.py
test/orm/test_dataclasses_py3k.py [new file with mode: 0644]
test/orm/test_mapper.py

diff --git a/doc/build/changelog/unreleased_14/5027.rst b/doc/build/changelog/unreleased_14/5027.rst
new file mode 100644 (file)
index 0000000..6fd2bc9
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 5027
+
+    Added support for direct mapping of Python classes that are defined using
+    the Python ``dataclasses`` decorator.    See the section
+    :ref:`mapping_dataclasses` for background.  Pull request courtesy Václav
+    Klusák.
\ No newline at end of file
index f76e4521161b899517fed49be3ea9187652d1c10..c156f08f1992c1876f3e7f5dd913fe80446ef8c7 100644 (file)
@@ -120,6 +120,88 @@ user-defined class, linked together with a :func:`.mapper`.  When we talk about
 "the behavior of :func:`.mapper`", this includes when using the Declarative system
 as well - it's still used, just behind the scenes.
 
+.. _mapping_dataclasses:
+
+Mapping dataclasses and attrs
+-----------------------------
+
+The dataclasses_ module, added in Python 3.7, provides a ``dataclass`` class
+decorator to automatically generate boilerplate definitions of ``__init__()``,
+``__eq__()``, ``__repr()__``, etc. methods. Another very popular library that does
+the same, and much more, is attrs_. Classes defined using either of these can
+be mapped with the following caveats.
+
+.. versionadded:: 1.4 Added support for direct mapping of Python dataclasses.
+
+The declarative "base" can't be used directly; a mapping function such as
+:func:`_declarative.instrument_declarative` or :func:`_orm.mapper` may be
+used.
+
+The ``dataclass`` decorator adds class attributes corresponding to simple default values.
+This is done mostly as documentation, these attributes are not necessary for the function
+of any of the generated methods. Mapping replaces these class attributes with property
+descriptors.
+
+Mapping of frozen ``dataclass`` and ``attrs`` classes is not possible, because the
+machinery used to enforce immutability interferes with loading.
+
+Example using classical mapping::
+
+    from __future__ import annotations
+    from dataclasses import dataclass, field
+    from typing import List
+
+    from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
+    from sqlalchemy.orm import mapper, relationship
+
+    @dataclass
+    class User:
+        id: int = field(init=False)
+        name: str = None
+        fullname: str = None
+        nickname: str = None
+        addresses: List[Address] = field(default_factory=list)
+
+    @dataclass
+    class Address:
+        id: int = field(init=False)
+        user_id: int = field(init=False)
+        email_address: str = None
+
+    metadata = MetaData()
+
+    user = Table(
+        'user',
+        metadata,
+        Column('id', Integer, primary_key=True),
+        Column('name', String(50)),
+        Column('fullname', String(50)),
+        Column('nickname', String(12)),
+    )
+
+    address = Table(
+        'address',
+        metadata,
+        Column('id', Integer, primary_key=True),
+        Column('user_id', Integer, ForeignKey('user.id')),
+        Column('email_address', String(50)),
+    )
+
+    mapper(User, user, properties={
+        'addresses': relationship(Address, backref='user', order_by=address.c.id),
+    })
+
+    mapper(Address, address)
+
+Note that ``User.id``, ``Address.id``, and ``Address.user_id`` are defined as ``field(init=False)``.
+This means that parameters for these won't be added to ``__init__()`` methods, but
+:class:`.Session` will still be able to set them after getting their values during flush
+from autoincrement or other default value generator. You can also give them a
+``None`` default value instead if you want to be able to specify their values in the constructor.
+
+.. _dataclasses: https://docs.python.org/3/library/dataclasses.html
+.. _attrs: https://www.attrs.org/en/stable/
+
 Runtime Introspection of Mappings, Objects
 ==========================================
 
index 39cf86e34fb59bf17228504b3e927724d8f13bb7..c2efa24a192a49fdaaba578d12ac23f4608c9322 100644 (file)
@@ -56,7 +56,7 @@ class DescriptorProperty(MapperProperty):
 
         if self.descriptor is None:
             desc = getattr(mapper.class_, self.key, None)
-            if mapper._is_userland_descriptor(desc):
+            if mapper._is_userland_descriptor(self.key, desc):
                 self.descriptor = desc
 
         if self.descriptor is None:
index 446f6790ed267a67d9d18ba49252e31a8029e256..755d4afc79b8858ed79351538c24aaa52469925d 100644 (file)
@@ -56,6 +56,12 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..util import HasMemoized
 
+try:
+    import dataclasses
+except ImportError:
+    # The dataclasses module was added in Python 3.7
+    dataclasses = None
+
 
 _mapper_registry = weakref.WeakKeyDictionary()
 _already_compiling = False
@@ -2632,7 +2638,7 @@ class Mapper(
 
         return result
 
-    def _is_userland_descriptor(self, obj):
+    def _is_userland_descriptor(self, assigned_name, obj):
         if isinstance(
             obj,
             (
@@ -2643,7 +2649,14 @@ class Mapper(
         ):
             return False
         else:
-            return True
+            return assigned_name not in self._dataclass_fields
+
+    @HasMemoized.memoized_attribute
+    def _dataclass_fields(self):
+        if dataclasses is None or not dataclasses.is_dataclass(self.class_):
+            return frozenset()
+
+        return {field.name for field in dataclasses.fields(self.class_)}
 
     def _should_exclude(self, name, assigned_name, local, column):
         """determine whether a particular property should be implicitly
@@ -2656,16 +2669,19 @@ class Mapper(
 
         # check for class-bound attributes and/or descriptors,
         # either local or from an inherited class
+        # ignore dataclass field default values
         if local:
             if self.class_.__dict__.get(
                 assigned_name, None
             ) is not None and self._is_userland_descriptor(
-                self.class_.__dict__[assigned_name]
+                assigned_name, self.class_.__dict__[assigned_name]
             ):
                 return True
         else:
             attr = self.class_manager._get_class_attr_mro(assigned_name, None)
-            if attr is not None and self._is_userland_descriptor(attr):
+            if attr is not None and self._is_userland_descriptor(
+                assigned_name, attr
+            ):
                 return True
 
         if (
index 9b8caac2efb4e17a34694fcb38e313469c3eb818..4114137d4e9608c4542ceb57a36348d17e725cee 100644 (file)
@@ -1142,6 +1142,10 @@ class SuiteRequirements(Requirements):
             "Python version 3.7 or greater is required.",
         )
 
+    @property
+    def dataclasses(self):
+        return self.python37
+
     @property
     def cpython(self):
         return exclusions.only_if(
diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py
new file mode 100644 (file)
index 0000000..1ac97b6
--- /dev/null
@@ -0,0 +1,227 @@
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Boolean
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+try:
+    import dataclasses
+except ImportError:
+    pass
+
+
+class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+    __requires__ = ("dataclasses",)
+
+    run_setup_classes = "each"
+    run_setup_mappers = "each"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "accounts",
+            metadata,
+            Column("account_id", Integer, primary_key=True),
+            Column("widget_count", Integer, nullable=False),
+        )
+        Table(
+            "widgets",
+            metadata,
+            Column("widget_id", Integer, primary_key=True),
+            Column(
+                "account_id",
+                Integer,
+                ForeignKey("accounts.account_id"),
+                nullable=False,
+            ),
+            Column("type", String(30), nullable=False),
+            Column("name", String(30), nullable=False),
+            Column("magic", Boolean),
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        @dataclasses.dataclass
+        class Widget:
+            name: Optional[str] = None
+
+        @dataclasses.dataclass
+        class SpecialWidget(Widget):
+            magic: bool = False
+
+        @dataclasses.dataclass
+        class Account:
+            account_id: int
+            widgets: List[Widget] = dataclasses.field(default_factory=list)
+            widget_count: int = dataclasses.field(init=False)
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        cls.classes.Account = Account
+        cls.classes.Widget = Widget
+        cls.classes.SpecialWidget = SpecialWidget
+
+    @classmethod
+    def setup_mappers(cls):
+        accounts = cls.tables.accounts
+        widgets = cls.tables.widgets
+
+        Account = cls.classes.Account
+        Widget = cls.classes.Widget
+        SpecialWidget = cls.classes.SpecialWidget
+
+        mapper(
+            Widget,
+            widgets,
+            polymorphic_on=widgets.c.type,
+            polymorphic_identity="normal",
+        )
+        mapper(
+            SpecialWidget,
+            widgets,
+            inherits=Widget,
+            polymorphic_identity="special",
+        )
+        mapper(Account, accounts, properties={"widgets": relationship(Widget)})
+
+    def check_account_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        account_id, widgets, widget_count = dataclasses.fields(obj)
+        eq_(account_id.name, "account_id")
+        eq_(widget_count.name, "widget_count")
+        eq_(widgets.name, "widgets")
+
+    def check_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        (name,) = dataclasses.fields(obj)
+        eq_(name.name, "name")
+
+    def check_special_widget_dataclass(self, obj):
+        assert dataclasses.is_dataclass(obj)
+        name, magic = dataclasses.fields(obj)
+        eq_(name.name, "name")
+        eq_(magic.name, "magic")
+
+    def data_fixture(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        return Account(
+            account_id=42,
+            widgets=[Widget("Foo"), SpecialWidget("Bar", magic=True)],
+        )
+
+    def check_data_fixture(self, account):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        self.check_account_dataclass(account)
+        eq_(account.account_id, 42)
+        eq_(account.widget_count, 2)
+        eq_(len(account.widgets), 2)
+
+        foo, bar = account.widgets
+
+        self.check_widget_dataclass(foo)
+        assert isinstance(foo, Widget)
+        eq_(foo.name, "Foo")
+
+        self.check_special_widget_dataclass(bar)
+        assert isinstance(bar, SpecialWidget)
+        eq_(bar.name, "Bar")
+        eq_(bar.magic, True)
+
+    def test_classes_are_still_dataclasses(self):
+        self.check_account_dataclass(self.classes.Account)
+        self.check_widget_dataclass(self.classes.Widget)
+        self.check_special_widget_dataclass(self.classes.SpecialWidget)
+
+    def test_construction(self):
+        SpecialWidget = self.classes.SpecialWidget
+
+        account = self.data_fixture()
+        self.check_data_fixture(account)
+
+        widget = SpecialWidget()
+        eq_(widget.name, None)
+        eq_(widget.magic, False)
+
+    def test_equality(self):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        eq_(Widget("Foo"), Widget("Foo"))
+        assert Widget("Foo") != Widget("Bar")
+        assert Widget("Foo") != SpecialWidget("Foo")
+
+    def test_asdict_and_astuple(self):
+        Widget = self.classes.Widget
+        SpecialWidget = self.classes.SpecialWidget
+
+        widget = Widget("Foo")
+        eq_(dataclasses.asdict(widget), {"name": "Foo"})
+        eq_(dataclasses.astuple(widget), ("Foo",))
+
+        widget = SpecialWidget("Bar", magic=True)
+        eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
+        eq_(dataclasses.astuple(widget), ("Bar", True))
+
+    def test_round_trip(self):
+        Account = self.classes.Account
+        account = self.data_fixture()
+
+        with Session(testing.db) as session:
+            session.add(account)
+            session.commit()
+
+        with Session(testing.db) as session:
+            a = session.query(Account).get(42)
+            self.check_data_fixture(a)
+
+    def test_appending_to_relationship(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        account = self.data_fixture()
+
+        with Session(testing.db) as session, session.begin():
+            session.add(account)
+            account.add_widget(Widget("Xyzzy"))
+
+        with Session(testing.db) as session:
+            a = session.query(Account).get(42)
+            eq_(a.widget_count, 3)
+            eq_(len(a.widgets), 3)
+
+    def test_filtering_on_relationship(self):
+        Account = self.classes.Account
+        Widget = self.classes.Widget
+        account = self.data_fixture()
+
+        with Session(testing.db) as session:
+            session.add(account)
+            session.commit()
+
+        with Session(testing.db) as session:
+            a = (
+                session.query(Account)
+                .join(Account.widgets)
+                .filter(Widget.name == "Foo")
+                .one()
+            )
+            self.check_data_fixture(a)
index f12d3fc084e46dd9ead4fe7989844d66b3765c83..f8133c6f05dd3302c7ea8d584af9f64c5e3fa120 100644 (file)
@@ -503,7 +503,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL):
             def foo(self):
                 pass
 
-        m._is_userland_descriptor(MyClass.foo)
+        assert m._is_userland_descriptor("foo", MyClass.foo)
 
     def test_configure_on_get_props_1(self):
         User, users = self.classes.User, self.tables.users