From: Václav Klusák Date: Sun, 16 Aug 2020 14:40:42 +0000 (+0200) Subject: Add support for classical mapping of dataclasses, Fixes: #5027 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F5516%2Fhead;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add support for classical mapping of dataclasses, Fixes: #5027 - Ignore simple default value class attributes. - Add tests for classical mapping of dataclasses. - Document limitations and show example. --- diff --git a/doc/build/orm/mapping_styles.rst b/doc/build/orm/mapping_styles.rst index f76e452116..c5dae19ab2 100644 --- a/doc/build/orm/mapping_styles.rst +++ b/doc/build/orm/mapping_styles.rst @@ -120,6 +120,86 @@ user-defined class, linked together with a :func:`.mapper`. When we talk about "the behavior of :func:`.mapper`", this includes when using the Declarative system as well - it's still used, just behind the scenes. +Mapping dataclasses and attrs +----------------------------- + +The dataclasses_ module, added in Python 3.7, provides a ``dataclass`` class +decorator to automatically generate boilerplate definitions of ``__init__()``, +``__eq__()``, ``__repr()__``, etc. methods. Another very popular library that does +the same, and much more, is attrs_. Classes defined using either of these can +be mapped with the following caveats. + +.. note:: + + * Only classical mapping is possible, not Declarative. Classes inheriting from + Declarative base would get processed by SQLAlchemy before being handed to the + ``dataclass`` or ``attrs`` decorator, and this would interfere with its function. + + * The ``dataclass`` decorator adds class attributes corresponding to simple default values. + This is done mostly as documentation, these attributes are not necessary for the function + of any of the generated methods. Mapping replaces these class attributes with property + descriptors. + + * Mapping of frozen ``dataclass`` and ``attrs`` classes is not possible, because the + machinery used to enforce immutability interferes with loading. + +Example:: + + from __future__ import annotations + from dataclasses import dataclass, field + from typing import List + + from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table + from sqlalchemy.orm import mapper, relationship + + @dataclass + class User: + id: int = field(init=False) + name: str = None + fullname: str = None + nickname: str = None + addresses: List[Address] = field(default_factory=list) + + @dataclass + class Address: + id: int = field(init=False) + user_id: int = field(init=False) + email_address: str = None + + metadata = MetaData() + + user = Table( + 'user', + metadata, + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('fullname', String(50)), + Column('nickname', String(12)), + ) + + address = Table( + 'address', + metadata, + Column('id', Integer, primary_key=True), + Column('user_id', Integer, ForeignKey('user.id')), + Column('email_address', String(50)), + ) + + mapper(User, user, properties={ + 'addresses': relationship(Address, backref='user', order_by=address.c.id), + }) + + mapper(Address, address) + +Note that ``User.id``, ``Address.id``, and ``Address.user_id`` are defined as ``field(init=False)``. +This means that parameters for these won't be added to ``__init__()`` methods, but +:class:`.Session` will still be able to set them after getting their values during flush +from autoincrement or other default value generator. You can also give them a +``None`` default value instead if you want to be able to specify their values in the constructor. + +.. _dataclasses: https://docs.python.org/3/library/dataclasses.html +.. _attrs: https://www.attrs.org/en/stable/ + Runtime Introspection of Mappings, Objects ========================================== diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 446f6790ed..b21cd9076a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -22,6 +22,12 @@ import sys import types import weakref +try: + import dataclasses +except ImportError: + # The dataclasses module was added in Python 3.7 + dataclasses = None + from . import attributes from . import exc as orm_exc from . import instrumentation @@ -2645,6 +2651,20 @@ class Mapper( else: return True + @HasMemoized.memoized_attribute + def _is_dataclass(self): + if dataclasses is not None: + return dataclasses.is_dataclass(self.class_) + else: + return False + + @HasMemoized.memoized_attribute + def _dataclass_fields(self): + return set(field.name for field in dataclasses.fields(self.class_)) + + def _is_dataclass_field(self, assigned_name): + return self._is_dataclass and assigned_name in self._dataclass_fields + def _should_exclude(self, name, assigned_name, local, column): """determine whether a particular property should be implicitly present on the class. @@ -2656,16 +2676,21 @@ class Mapper( # check for class-bound attributes and/or descriptors, # either local or from an inherited class + # ignore dataclass field default values if local: if self.class_.__dict__.get( assigned_name, None ) is not None and self._is_userland_descriptor( self.class_.__dict__[assigned_name] - ): + ) and not self._is_dataclass_field(assigned_name): return True else: attr = self.class_manager._get_class_attr_mro(assigned_name, None) - if attr is not None and self._is_userland_descriptor(attr): + if ( + attr is not None + and self._is_userland_descriptor(attr) + and not self._is_dataclass_field(assigned_name) + ): return True if ( diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py new file mode 100644 index 0000000000..6ffe745c9e --- /dev/null +++ b/test/orm/test_dataclasses_py3k.py @@ -0,0 +1,232 @@ +from typing import List, Optional + +from sqlalchemy import Boolean +from sqlalchemy import ForeignKey +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import mapper +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy import testing +from sqlalchemy.testing import fixtures +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + +import pytest + + +# The dataclasses module was added in Python 3.7 +dataclasses = pytest.importorskip("dataclasses") + + +class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL): + run_setup_classes = "each" + run_setup_mappers = "each" + + @classmethod + def define_tables(cls, metadata): + Table( + "accounts", + metadata, + Column("account_id", Integer, primary_key=True), + Column("widget_count", Integer, nullable=False), + ) + Table( + "widgets", + metadata, + Column("widget_id", Integer, primary_key=True), + Column( + "account_id", + Integer, + ForeignKey("accounts.account_id"), + nullable=False, + ), + Column("type", String, nullable=False), + Column("name", String, nullable=False), + Column("magic", Boolean), + ) + + @classmethod + def setup_classes(cls): + @dataclasses.dataclass + class Widget: + name: Optional[str] = None + + @dataclasses.dataclass + class SpecialWidget(Widget): + magic: bool = False + + @dataclasses.dataclass + class Account: + account_id: int + widgets: List[Widget] = dataclasses.field(default_factory=list) + widget_count: int = dataclasses.field(init=False) + + def __post_init__(self): + self.widget_count = len(self.widgets) + + def add_widget(self, widget: Widget): + self.widgets.append(widget) + self.widget_count += 1 + + cls.classes.Account = Account + cls.classes.Widget = Widget + cls.classes.SpecialWidget = SpecialWidget + + @classmethod + def setup_mappers(cls): + accounts = cls.tables.accounts + widgets = cls.tables.widgets + + Account = cls.classes.Account + Widget = cls.classes.Widget + SpecialWidget = cls.classes.SpecialWidget + + mapper( + Widget, + widgets, + polymorphic_on=widgets.c.type, + polymorphic_identity="normal", + ) + mapper( + SpecialWidget, + widgets, + inherits=Widget, + polymorphic_identity="special", + ) + mapper(Account, accounts, properties={ + "widgets": relationship(Widget), + }) + + def check_account_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + account_id, widgets, widget_count = dataclasses.fields(obj) + assert account_id.name == "account_id" + assert widget_count.name == "widget_count" + assert widgets.name == "widgets" + + def check_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + name, = dataclasses.fields(obj) + assert name.name == "name" + + def check_special_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + name, magic = dataclasses.fields(obj) + assert name.name == "name" + assert magic.name == "magic" + + def data_fixture(self): + Account = self.classes.Account + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + return Account( + account_id=42, + widgets=[ + Widget("Foo"), + SpecialWidget("Bar", magic=True), + ], + ) + + def check_data_fixture(self, account): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + self.check_account_dataclass(account) + assert account.account_id == 42 + assert account.widget_count == len(account.widgets) == 2 + + foo, bar = account.widgets + + self.check_widget_dataclass(foo) + assert isinstance(foo, Widget) + assert foo.name == "Foo" + + self.check_special_widget_dataclass(bar) + assert isinstance(bar, SpecialWidget) + assert bar.name == "Bar" + assert bar.magic == True + + def test_classes_are_still_dataclasses(self): + self.check_account_dataclass(self.classes.Account) + self.check_widget_dataclass(self.classes.Widget) + self.check_special_widget_dataclass(self.classes.SpecialWidget) + + def test_construction(self): + SpecialWidget = self.classes.SpecialWidget + + account = self.data_fixture() + self.check_data_fixture(account) + + widget = SpecialWidget() + assert widget.name == None + assert widget.magic == False + + def test_equality(self): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + assert Widget("Foo") == Widget("Foo") + assert Widget("Foo") != Widget("Bar") + assert Widget("Foo") != SpecialWidget("Foo") + + def test_asdict_and_astuple(self): + Widget = self.classes.Widget + SpecialWidget = self.classes.SpecialWidget + + widget = Widget("Foo") + assert dataclasses.asdict(widget) == {"name": "Foo"} + assert dataclasses.astuple(widget) == ("Foo",) + + widget = SpecialWidget("Bar", magic=True) + assert dataclasses.asdict(widget) == { + "name": "Bar", "magic": True + } + assert dataclasses.astuple(widget) == ("Bar", True) + + def test_round_trip(self): + Account = self.classes.Account + account = self.data_fixture() + + session = Session(testing.db) + session.add(account) + session.commit() + session.close() + + a = session.query(Account).get(42) + self.check_data_fixture(a) + + def test_appending_to_relationship(self): + Account = self.classes.Account + Widget = self.classes.Widget + account = self.data_fixture() + + session = Session(testing.db) + session.add(account) + session.commit() + + account.add_widget(Widget("Xyzzy")) + session.commit() + session.close() + + a = session.query(Account).get(42) + assert a.widget_count == len(a.widgets) == 3 + + def test_filtering_on_relationship(self): + Account = self.classes.Account + Widget = self.classes.Widget + account = self.data_fixture() + + session = Session(testing.db) + session.add(account) + session.commit() + session.close() + + a = ( + session.query(Account) + .join(Account.widgets) + .filter(Widget.name == "Foo") + .one() + ) + self.check_data_fixture(a)