"the behavior of :func:`.mapper`", this includes when using the Declarative system
as well - it's still used, just behind the scenes.
+.. _mapping_dataclasses:
+
+Mapping dataclasses and attrs
+-----------------------------
+
+The dataclasses_ module, added in Python 3.7, provides a ``dataclass`` class
+decorator to automatically generate boilerplate definitions of ``__init__()``,
+``__eq__()``, ``__repr()__``, etc. methods. Another very popular library that does
+the same, and much more, is attrs_. Classes defined using either of these can
+be mapped with the following caveats.
+
+.. versionadded:: 1.4 Added support for direct mapping of Python dataclasses.
+
+The declarative "base" can't be used directly; a mapping function such as
+:func:`_declarative.instrument_declarative` or :func:`_orm.mapper` may be
+used.
+
+The ``dataclass`` decorator adds class attributes corresponding to simple default values.
+This is done mostly as documentation, these attributes are not necessary for the function
+of any of the generated methods. Mapping replaces these class attributes with property
+descriptors.
+
+Mapping of frozen ``dataclass`` and ``attrs`` classes is not possible, because the
+machinery used to enforce immutability interferes with loading.
+
+Example using classical mapping::
+
+ from __future__ import annotations
+ from dataclasses import dataclass, field
+ from typing import List
+
+ from sqlalchemy import Column, ForeignKey, Integer, MetaData, String, Table
+ from sqlalchemy.orm import mapper, relationship
+
+ @dataclass
+ class User:
+ id: int = field(init=False)
+ name: str = None
+ fullname: str = None
+ nickname: str = None
+ addresses: List[Address] = field(default_factory=list)
+
+ @dataclass
+ class Address:
+ id: int = field(init=False)
+ user_id: int = field(init=False)
+ email_address: str = None
+
+ metadata = MetaData()
+
+ user = Table(
+ 'user',
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String(50)),
+ Column('fullname', String(50)),
+ Column('nickname', String(12)),
+ )
+
+ address = Table(
+ 'address',
+ metadata,
+ Column('id', Integer, primary_key=True),
+ Column('user_id', Integer, ForeignKey('user.id')),
+ Column('email_address', String(50)),
+ )
+
+ mapper(User, user, properties={
+ 'addresses': relationship(Address, backref='user', order_by=address.c.id),
+ })
+
+ mapper(Address, address)
+
+Note that ``User.id``, ``Address.id``, and ``Address.user_id`` are defined as ``field(init=False)``.
+This means that parameters for these won't be added to ``__init__()`` methods, but
+:class:`.Session` will still be able to set them after getting their values during flush
+from autoincrement or other default value generator. You can also give them a
+``None`` default value instead if you want to be able to specify their values in the constructor.
+
+.. _dataclasses: https://docs.python.org/3/library/dataclasses.html
+.. _attrs: https://www.attrs.org/en/stable/
+
Runtime Introspection of Mappings, Objects
==========================================
from ..sql import visitors
from ..util import HasMemoized
+try:
+ import dataclasses
+except ImportError:
+ # The dataclasses module was added in Python 3.7
+ dataclasses = None
+
_mapper_registry = weakref.WeakKeyDictionary()
_already_compiling = False
return result
- def _is_userland_descriptor(self, obj):
+ def _is_userland_descriptor(self, assigned_name, obj):
if isinstance(
obj,
(
):
return False
else:
- return True
+ return assigned_name not in self._dataclass_fields
+
+ @HasMemoized.memoized_attribute
+ def _dataclass_fields(self):
+ if dataclasses is None or not dataclasses.is_dataclass(self.class_):
+ return frozenset()
+
+ return {field.name for field in dataclasses.fields(self.class_)}
def _should_exclude(self, name, assigned_name, local, column):
"""determine whether a particular property should be implicitly
# check for class-bound attributes and/or descriptors,
# either local or from an inherited class
+ # ignore dataclass field default values
if local:
if self.class_.__dict__.get(
assigned_name, None
) is not None and self._is_userland_descriptor(
- self.class_.__dict__[assigned_name]
+ assigned_name, self.class_.__dict__[assigned_name]
):
return True
else:
attr = self.class_manager._get_class_attr_mro(assigned_name, None)
- if attr is not None and self._is_userland_descriptor(attr):
+ if attr is not None and self._is_userland_descriptor(
+ assigned_name, attr
+ ):
return True
if (
--- /dev/null
+from typing import List
+from typing import Optional
+
+from sqlalchemy import Boolean
+from sqlalchemy import ForeignKey
+from sqlalchemy import Integer
+from sqlalchemy import String
+from sqlalchemy import testing
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+try:
+ import dataclasses
+except ImportError:
+ pass
+
+
+class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
+ __requires__ = ("dataclasses",)
+
+ run_setup_classes = "each"
+ run_setup_mappers = "each"
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "accounts",
+ metadata,
+ Column("account_id", Integer, primary_key=True),
+ Column("widget_count", Integer, nullable=False),
+ )
+ Table(
+ "widgets",
+ metadata,
+ Column("widget_id", Integer, primary_key=True),
+ Column(
+ "account_id",
+ Integer,
+ ForeignKey("accounts.account_id"),
+ nullable=False,
+ ),
+ Column("type", String(30), nullable=False),
+ Column("name", String(30), nullable=False),
+ Column("magic", Boolean),
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ @dataclasses.dataclass
+ class Widget:
+ name: Optional[str] = None
+
+ @dataclasses.dataclass
+ class SpecialWidget(Widget):
+ magic: bool = False
+
+ @dataclasses.dataclass
+ class Account:
+ account_id: int
+ widgets: List[Widget] = dataclasses.field(default_factory=list)
+ widget_count: int = dataclasses.field(init=False)
+
+ def __post_init__(self):
+ self.widget_count = len(self.widgets)
+
+ def add_widget(self, widget: Widget):
+ self.widgets.append(widget)
+ self.widget_count += 1
+
+ cls.classes.Account = Account
+ cls.classes.Widget = Widget
+ cls.classes.SpecialWidget = SpecialWidget
+
+ @classmethod
+ def setup_mappers(cls):
+ accounts = cls.tables.accounts
+ widgets = cls.tables.widgets
+
+ Account = cls.classes.Account
+ Widget = cls.classes.Widget
+ SpecialWidget = cls.classes.SpecialWidget
+
+ mapper(
+ Widget,
+ widgets,
+ polymorphic_on=widgets.c.type,
+ polymorphic_identity="normal",
+ )
+ mapper(
+ SpecialWidget,
+ widgets,
+ inherits=Widget,
+ polymorphic_identity="special",
+ )
+ mapper(Account, accounts, properties={"widgets": relationship(Widget)})
+
+ def check_account_dataclass(self, obj):
+ assert dataclasses.is_dataclass(obj)
+ account_id, widgets, widget_count = dataclasses.fields(obj)
+ eq_(account_id.name, "account_id")
+ eq_(widget_count.name, "widget_count")
+ eq_(widgets.name, "widgets")
+
+ def check_widget_dataclass(self, obj):
+ assert dataclasses.is_dataclass(obj)
+ (name,) = dataclasses.fields(obj)
+ eq_(name.name, "name")
+
+ def check_special_widget_dataclass(self, obj):
+ assert dataclasses.is_dataclass(obj)
+ name, magic = dataclasses.fields(obj)
+ eq_(name.name, "name")
+ eq_(magic.name, "magic")
+
+ def data_fixture(self):
+ Account = self.classes.Account
+ Widget = self.classes.Widget
+ SpecialWidget = self.classes.SpecialWidget
+
+ return Account(
+ account_id=42,
+ widgets=[Widget("Foo"), SpecialWidget("Bar", magic=True)],
+ )
+
+ def check_data_fixture(self, account):
+ Widget = self.classes.Widget
+ SpecialWidget = self.classes.SpecialWidget
+
+ self.check_account_dataclass(account)
+ eq_(account.account_id, 42)
+ eq_(account.widget_count, 2)
+ eq_(len(account.widgets), 2)
+
+ foo, bar = account.widgets
+
+ self.check_widget_dataclass(foo)
+ assert isinstance(foo, Widget)
+ eq_(foo.name, "Foo")
+
+ self.check_special_widget_dataclass(bar)
+ assert isinstance(bar, SpecialWidget)
+ eq_(bar.name, "Bar")
+ eq_(bar.magic, True)
+
+ def test_classes_are_still_dataclasses(self):
+ self.check_account_dataclass(self.classes.Account)
+ self.check_widget_dataclass(self.classes.Widget)
+ self.check_special_widget_dataclass(self.classes.SpecialWidget)
+
+ def test_construction(self):
+ SpecialWidget = self.classes.SpecialWidget
+
+ account = self.data_fixture()
+ self.check_data_fixture(account)
+
+ widget = SpecialWidget()
+ eq_(widget.name, None)
+ eq_(widget.magic, False)
+
+ def test_equality(self):
+ Widget = self.classes.Widget
+ SpecialWidget = self.classes.SpecialWidget
+
+ eq_(Widget("Foo"), Widget("Foo"))
+ assert Widget("Foo") != Widget("Bar")
+ assert Widget("Foo") != SpecialWidget("Foo")
+
+ def test_asdict_and_astuple(self):
+ Widget = self.classes.Widget
+ SpecialWidget = self.classes.SpecialWidget
+
+ widget = Widget("Foo")
+ eq_(dataclasses.asdict(widget), {"name": "Foo"})
+ eq_(dataclasses.astuple(widget), ("Foo",))
+
+ widget = SpecialWidget("Bar", magic=True)
+ eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
+ eq_(dataclasses.astuple(widget), ("Bar", True))
+
+ def test_round_trip(self):
+ Account = self.classes.Account
+ account = self.data_fixture()
+
+ with Session(testing.db) as session:
+ session.add(account)
+ session.commit()
+
+ with Session(testing.db) as session:
+ a = session.query(Account).get(42)
+ self.check_data_fixture(a)
+
+ def test_appending_to_relationship(self):
+ Account = self.classes.Account
+ Widget = self.classes.Widget
+ account = self.data_fixture()
+
+ with Session(testing.db) as session, session.begin():
+ session.add(account)
+ account.add_widget(Widget("Xyzzy"))
+
+ with Session(testing.db) as session:
+ a = session.query(Account).get(42)
+ eq_(a.widget_count, 3)
+ eq_(len(a.widgets), 3)
+
+ def test_filtering_on_relationship(self):
+ Account = self.classes.Account
+ Widget = self.classes.Widget
+ account = self.data_fixture()
+
+ with Session(testing.db) as session:
+ session.add(account)
+ session.commit()
+
+ with Session(testing.db) as session:
+ a = (
+ session.query(Account)
+ .join(Account.widgets)
+ .filter(Widget.name == "Foo")
+ .one()
+ )
+ self.check_data_fixture(a)