From: Václav Klusák Date: Mon, 17 Aug 2020 15:58:56 +0000 (-0400) Subject: Add support for classical mapping of dataclasses X-Git-Tag: rel_1_4_0b1~132^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f806491fca4b08623d7fcffc375bd5cbe3790e5f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add support for classical mapping of dataclasses Added support for direct mapping of Python classes that are defined using the Python ``dataclasses`` decorator. See the section :ref:`mapping_dataclasses` for background. Pull request courtesy Václav Klusák. Fixes: #5027 Closes: #5516 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/5516 Pull-request-sha: bb48c63d1561ca48c954ad9f84a3eb2646571115 Change-Id: Ie33db2aae4adeeb5d99633fe926b9c30bab0b885 --- diff --git a/doc/build/changelog/unreleased_14/5027.rst b/doc/build/changelog/unreleased_14/5027.rst new file mode 100644 index 0000000000..6fd2bc9b25 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5027.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, orm + :tickets: 5027 + + Added support for direct mapping of Python classes that are defined using + the Python ``dataclasses`` decorator. See the section + :ref:`mapping_dataclasses` for background. Pull request courtesy Václav + Klusák. \ No newline at end of file diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst index f76e452116..c156f08f19 100644 --- a/doc/build/orm/mapping_styles.rst +++ b/doc/build/orm/mapping_styles.rst @@ -120,6 +120,88 @@ user-defined class, linked together with a :func:`.mapper`. When we talk about "the behavior of :func:`.mapper`", this includes when using the Declarative system as well - it's still used, just behind the scenes. +.. _mapping_dataclasses: + +Mapping dataclasses and attrs +----------------------------- + +The dataclasses_ module, added in Python 3.7, provides a ``dataclass`` class +decorator to automatically generate boilerplate definitions of ``__init__()``, +``__eq__()``, ``__repr()__``, etc. methods. Another very popular library that does +the same, and much more, is attrs_. Classes defined using either of these can +be mapped with the following caveats. + +.. versionadded:: 1.4 Added support for direct mapping of Python dataclasses. + +The declarative "base" can't be used directly; a mapping function such as +:func:`_declarative.instrument_declarative` or :func:`_orm.mapper` may be +used. + +The ``dataclass`` decorator adds class attributes corresponding to simple default values. +This is done mostly as documentation, these attributes are not necessary for the function +of any of the generated methods. Mapping replaces these class attributes with property +descriptors. + +Mapping of frozen ``dataclass`` and ``attrs`` classes is not possible, because the +machinery used to enforce immutability interferes with loading. + +Example using classical mapping:: + + from __future__ import annotations + from dataclasses import dataclass, field + from typing import List + + from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table + from sqlalchemy.orm import mapper, relationship + + @dataclass + class User: + id: int = field(init=False) + name: str = None + fullname: str = None + nickname: str = None + addresses: List[Address] = field(default_factory=list) + + @dataclass + class Address: + id: int = field(init=False) + user_id: int = field(init=False) + email_address: str = None + + metadata = MetaData() + + user = Table( + 'user', + metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('fullname', String(50)), + Column('nickname', String(12)), + ) + + address = Table( + 'address', + metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('user.id')), + Column('email_address', String(50)), + ) + + mapper(User, user, properties={ + 'addresses': relationship(Address, backref='user', order_by=address.c.id), + }) + + mapper(Address, address) + +Note that ``User.id``, ``Address.id``, and ``Address.user_id`` are defined as ``field(init=False)``. +This means that parameters for these won't be added to ``__init__()`` methods, but +:class:`.Session` will still be able to set them after getting their values during flush +from autoincrement or other default value generator. You can also give them a +``None`` default value instead if you want to be able to specify their values in the constructor. + +.. _dataclasses: https://docs.python.org/3/library/dataclasses.html +.. _attrs: https://www.attrs.org/en/stable/ + Runtime Introspection of Mappings, Objects ========================================== diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 39cf86e34f..c2efa24a19 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -56,7 +56,7 @@ class DescriptorProperty(MapperProperty): if self.descriptor is None: desc = getattr(mapper.class_, self.key, None) - if mapper._is_userland_descriptor(desc): + if mapper._is_userland_descriptor(self.key, desc): self.descriptor = desc if self.descriptor is None: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 446f6790ed..755d4afc79 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -56,6 +56,12 @@ from ..sql import util as sql_util from ..sql import visitors from ..util import HasMemoized +try: + import dataclasses +except ImportError: + # The dataclasses module was added in Python 3.7 + dataclasses = None + _mapper_registry = weakref.WeakKeyDictionary() _already_compiling = False @@ -2632,7 +2638,7 @@ class Mapper( return result - def _is_userland_descriptor(self, obj): + def _is_userland_descriptor(self, assigned_name, obj): if isinstance( obj, ( @@ -2643,7 +2649,14 @@ class Mapper( ): return False else: - return True + return assigned_name not in self._dataclass_fields + + @HasMemoized.memoized_attribute + def _dataclass_fields(self): + if dataclasses is None or not dataclasses.is_dataclass(self.class_): + return frozenset() + + return {field.name for field in dataclasses.fields(self.class_)} def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly @@ -2656,16 +2669,19 @@ class Mapper( # check for class-bound attributes and/or descriptors, # either local or from an inherited class + # ignore dataclass field default values if local: if self.class_.__dict__.get( assigned_name, None ) is not None and self._is_userland_descriptor( - self.class_.__dict__[assigned_name] + assigned_name, self.class_.__dict__[assigned_name] ): return True else: attr = self.class_manager._get_class_attr_mro(assigned_name, None) - if attr is not None and self._is_userland_descriptor(attr): + if attr is not None and self._is_userland_descriptor( + assigned_name, attr + ): return True if ( diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 9b8caac2ef..4114137d4e 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1142,6 +1142,10 @@ class SuiteRequirements(Requirements): "Python version 3.7 or greater is required.", ) + @property + def dataclasses(self): + return self.python37 + @property def cpython(self): return exclusions.only_if( diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py new file mode 100644 index 0000000000..1ac97b64a9 --- /dev/null +++ b/test/orm/test_dataclasses_py3k.py @@ -0,0 +1,227 @@ +from typing import List +from typing import Optional + +from sqlalchemy import Boolean +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy import testing +from sqlalchemy.orm import mapper +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + +try: + import dataclasses +except ImportError: + pass + + +class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL): + __requires__ = ("dataclasses",) + + run_setup_classes = "each" + run_setup_mappers = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "accounts", + metadata, + Column("account_id", Integer, primary_key=True), + Column("widget_count", Integer, nullable=False), + ) + Table( + "widgets", + metadata, + Column("widget_id", Integer, primary_key=True), + Column( + "account_id", + Integer, + ForeignKey("accounts.account_id"), + nullable=False, + ), + Column("type", String(30), nullable=False), + Column("name", String(30), nullable=False), + Column("magic", Boolean), + ) + + @classmethod + def setup_classes(cls): + @dataclasses.dataclass + class Widget: + name: Optional[str] = None + + @dataclasses.dataclass + class SpecialWidget(Widget): + magic: bool = False + + @dataclasses.dataclass + class Account: + account_id: int + widgets: List[Widget] = dataclasses.field(default_factory=list) + widget_count: int = dataclasses.field(init=False) + + def __post_init__(self): + self.widget_count = len(self.widgets) + + def add_widget(self, widget: Widget): + self.widgets.append(widget) + self.widget_count += 1 + + cls.classes.Account = Account + cls.classes.Widget = Widget + cls.classes.SpecialWidget = SpecialWidget + + @classmethod + def setup_mappers(cls): + accounts = cls.tables.accounts + widgets = cls.tables.widgets + + Account = cls.classes.Account + Widget = cls.classes.Widget + SpecialWidget = cls.classes.SpecialWidget + + mapper( + Widget, + widgets, + polymorphic_on=widgets.c.type, + polymorphic_identity="normal", + ) + mapper( + SpecialWidget, + widgets, + inherits=Widget, + polymorphic_identity="special", + ) + mapper(Account, accounts, properties={"widgets": relationship(Widget)}) + + def check_account_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + account_id, widgets, widget_count = dataclasses.fields(obj) + eq_(account_id.name, "account_id") + eq_(widget_count.name, "widget_count") + eq_(widgets.name, "widgets") + + def check_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + (name,) = dataclasses.fields(obj) + eq_(name.name, "name") + + def check_special_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + name, magic = dataclasses.fields(obj) + eq_(name.name, "name") + eq_(magic.name, "magic") + + def data_fixture(self): + Account = self.classes.Account + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + return Account( + account_id=42, + widgets=[Widget("Foo"), SpecialWidget("Bar", magic=True)], + ) + + def check_data_fixture(self, account): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + self.check_account_dataclass(account) + eq_(account.account_id, 42) + eq_(account.widget_count, 2) + eq_(len(account.widgets), 2) + + foo, bar = account.widgets + + self.check_widget_dataclass(foo) + assert isinstance(foo, Widget) + eq_(foo.name, "Foo") + + self.check_special_widget_dataclass(bar) + assert isinstance(bar, SpecialWidget) + eq_(bar.name, "Bar") + eq_(bar.magic, True) + + def test_classes_are_still_dataclasses(self): + self.check_account_dataclass(self.classes.Account) + self.check_widget_dataclass(self.classes.Widget) + self.check_special_widget_dataclass(self.classes.SpecialWidget) + + def test_construction(self): + SpecialWidget = self.classes.SpecialWidget + + account = self.data_fixture() + self.check_data_fixture(account) + + widget = SpecialWidget() + eq_(widget.name, None) + eq_(widget.magic, False) + + def test_equality(self): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + eq_(Widget("Foo"), Widget("Foo")) + assert Widget("Foo") != Widget("Bar") + assert Widget("Foo") != SpecialWidget("Foo") + + def test_asdict_and_astuple(self): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + widget = Widget("Foo") + eq_(dataclasses.asdict(widget), {"name": "Foo"}) + eq_(dataclasses.astuple(widget), ("Foo",)) + + widget = SpecialWidget("Bar", magic=True) + eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True}) + eq_(dataclasses.astuple(widget), ("Bar", True)) + + def test_round_trip(self): + Account = self.classes.Account + account = self.data_fixture() + + with Session(testing.db) as session: + session.add(account) + session.commit() + + with Session(testing.db) as session: + a = session.query(Account).get(42) + self.check_data_fixture(a) + + def test_appending_to_relationship(self): + Account = self.classes.Account + Widget = self.classes.Widget + account = self.data_fixture() + + with Session(testing.db) as session, session.begin(): + session.add(account) + account.add_widget(Widget("Xyzzy")) + + with Session(testing.db) as session: + a = session.query(Account).get(42) + eq_(a.widget_count, 3) + eq_(len(a.widgets), 3) + + def test_filtering_on_relationship(self): + Account = self.classes.Account + Widget = self.classes.Widget + account = self.data_fixture() + + with Session(testing.db) as session: + session.add(account) + session.commit() + + with Session(testing.db) as session: + a = ( + session.query(Account) + .join(Account.widgets) + .filter(Widget.name == "Foo") + .one() + ) + self.check_data_fixture(a) diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index f12d3fc084..f8133c6f05 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -503,7 +503,7 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): def foo(self): pass - m._is_userland_descriptor(MyClass.foo) + assert m._is_userland_descriptor("foo", MyClass.foo) def test_configure_on_get_props_1(self): User, users = self.classes.User, self.tables.users