]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support callable mapped attributes in dataclass mixins
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 16:55:43 +0000 (12:55 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 19 Mar 2021 16:57:40 +0000 (12:57 -0400)
Added support for the :class:`_orm.declared_attr` object to work in the
context of dataclass fields.

Fixes: #6100
Change-Id: Ifaf4a6482c866d6cfee99d8bc2c6294d923460d7

doc/build/changelog/unreleased_14/6100.rst [new file with mode: 0644]
doc/build/orm/declarative_mixins.rst
doc/build/orm/mapping_styles.rst
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/test_dataclasses_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6100.rst b/doc/build/changelog/unreleased_14/6100.rst
new file mode 100644 (file)
index 0000000..197805b
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, orm, dataclasses
+    :tickets: 6100
+
+    Added support for the :class:`_orm.declared_attr` object to work in the
+    context of dataclass fields.
+
+    .. seealso::
+
+        :ref:`orm_declarative_dataclasses_mixin`
index a0229fe88bbb4570178912dcb403393464c22713..309c3226020f3544434ddeb2bc556ab0327523fd 100644 (file)
@@ -154,6 +154,7 @@ will resolve them at class construction time::
         __tablename__='test'
         id =  Column(Integer, primary_key=True)
 
+.. _orm_declarative_mixins_relationships:
 
 Mixing in Relationships
 ~~~~~~~~~~~~~~~~~~~~~~~
index d6c7d3280d212fa644655f3dcd072fbc023a97c5..09f01102bf7920a78562703f8f97f904dc765515 100644 (file)
@@ -406,6 +406,72 @@ association::
             default=None, metadata={"sa": Column(String(50))}
         )
 
+.. _orm_declarative_dataclasses_mixin:
+
+Using Declarative Mixins with Dataclasses
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+In the section :ref:`orm_mixins_toplevel`, Declarative Mixin classes
+are introduced.  One requirement of declarative mixins is that certain
+constructs that can't be easily duplicated must be given as callables,
+using the :class:`_orm.declared_attr` decorator, such as in the
+example at :ref:`orm_declarative_mixins_relationships`::
+
+    class RefTargetMixin(object):
+        @declared_attr
+        def target_id(cls):
+            return Column('target_id', ForeignKey('target.id'))
+
+        @declared_attr
+        def target(cls):
+            return relationship("Target")
+
+This form is supported within the Dataclasses ``field()`` object by using
+a lambda to indicate the SQLAlchemy construct inside the ``field()``.
+Using :func:`_orm.declared_attr` to surround the lambda is optional.
+If we wanted to produce our ``User`` class above where the ORM fields
+came from a mixin that is itself a dataclass, the form would be::
+
+    @dataclass
+    class UserMixin:
+        __tablename__ = "user"
+
+        __sa_dataclass_metadata_key__ = "sa"
+
+        id: int = field(
+            init=False, metadata={"sa": Column(Integer, primary_key=True)}
+        )
+
+        addresses: List[Address] = field(
+            default_factory=list, metadata={"sa": lambda: relationship("Address")}
+        )
+
+    @dataclass
+    class AddressMixin:
+        __tablename__ = "address"
+        __sa_dataclass_metadata_key__ = "sa"
+        id: int = field(
+            init=False, metadata={"sa": Column(Integer, primary_key=True)}
+        )
+        user_id: int = field(
+            init=False, metadata={"sa": lambda: Column(ForeignKey("user.id"))}
+        )
+        email_address: str = field(
+            default=None, metadata={"sa": Column(String(50))}
+        )
+
+    @mapper_registry.mapped
+    class User(UserMixin):
+        pass
+
+    @mapper_registry.mapped
+    class Address(AddressMixin):
+      pass
+
+.. versionadded:: 1.4.2  Added support for "declared attr" style mixin attributes,
+   namely :func:`_orm.relationship` constructs as well as :class:`_schema.Column`
+   objects with foreign key declarations, to be used within "Dataclasses
+   with Declarative Table" style mappings.
 
 .. _orm_declarative_attrs_imperative_table:
 
index 0266c973aecc77d83d55b632bce13cbaeab1f5d4..ef53e2d399ad72925e4201df7144d4af3ed475a2 100644 (file)
@@ -126,27 +126,34 @@ class declared_attr(interfaces._MappedAttribute, property):
     """Mark a class-level method as representing the definition of
     a mapped property or special declarative member name.
 
-    @declared_attr turns the attribute into a scalar-like
-    property that can be invoked from the uninstantiated class.
-    Declarative treats attributes specifically marked with
-    @declared_attr as returning a construct that is specific
-    to mapping or declarative table configuration.  The name
-    of the attribute is that of what the non-dynamic version
-    of the attribute would be.
-
-    @declared_attr is more often than not applicable to mixins,
-    to define relationships that are to be applied to different
-    implementors of the class::
+    :class:`_orm.declared_attr` is typically applied as a decorator to a class
+    level method, turning the attribute into a scalar-like property that can be
+    invoked from the uninstantiated class. The Declarative mapping process
+    looks for these :class:`_orm.declared_attr` callables as it scans classe,
+    and assumes any attribute marked with :class:`_orm.declared_attr` will be a
+    callable that will produce an object specific to the Declarative mapping or
+    table configuration.
+
+    :class:`_orm.declared_attr` is usually applicable to mixins, to define
+    relationships that are to be applied to different implementors of the
+    class. It is also used to define :class:`_schema.Column` objects that
+    include the :class:`_schema.ForeignKey` construct, as these cannot be
+    easily reused across different mappings.  The example below illustrates
+    both::
 
         class ProvidesUser(object):
             "A mixin that adds a 'user' relationship to classes."
 
+            @declared_attr
+            def user_id(self):
+                return Column(ForeignKey("user_account.id"))
+
             @declared_attr
             def user(self):
                 return relationship("User")
 
-    It also can be applied to mapped classes, such as to provide
-    a "polymorphic" scheme for inheritance::
+    :class:`_orm.declared_attr` can also be applied to mapped classes, such as
+    to provide a "polymorphic" scheme for inheritance::
 
         class Employee(Base):
             id = Column(Integer, primary_key=True)
@@ -166,12 +173,43 @@ class declared_attr(interfaces._MappedAttribute, property):
                 else:
                     return {"polymorphic_identity":cls.__name__}
 
-    """
+    To use :class:`_orm.declared_attr` inside of a Python dataclass
+    as discussed at :ref:`orm_declarative_dataclasses_declarative_table`,
+    it may be placed directly inside the field metadata using a lambda::
+
+        @dataclass
+        class AddressMixin:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            user_id: int = field(
+                init=False, metadata={"sa": declared_attr(lambda: Column(ForeignKey("user.id")))}
+            )
+            user: User = field(
+                init=False, metadata={"sa": declared_attr(lambda: relationship(User))}
+            )
+
+    :class:`_orm.declared_attr` also may be omitted from this form using a
+    lambda directly, as in::
+
+        user: User = field(
+            init=False, metadata={"sa": lambda: relationship(User)}
+        )
+
+    .. seealso::
+
+        :ref:`orm_mixins_toplevel` - illustrates how to use Declarative Mixins
+        which is the primary use case for :class:`_orm.declared_attr`
+
+        :ref:`orm_declarative_dataclasses_mixin` - illustrates special forms
+        for use with Python dataclasses
+
+    """  # noqa E501
 
-    def __init__(self, fget, cascading=False):
+    def __init__(self, fget, cascading=False, _is_dataclass=False):
         super(declared_attr, self).__init__(fget)
         self.__doc__ = fget.__doc__
         self._cascading = cascading
+        self._is_dataclass = _is_dataclass
 
     def __get__(desc, self, cls):
         # the declared_attr needs to make use of a cache that exists
index 0a73288fd6421256f7acbd61b0ee134c7c21f6d4..e55056fdff4979a8efc750f676d0ff4c80b45674 100644 (file)
@@ -356,11 +356,16 @@ class _ClassScanMapperConfig(_MapperConfig):
             absent = object()
 
             def attribute_is_overridden(key, obj):
+                if _is_declarative_props(obj):
+                    obj = obj.fget
+
                 # this function likely has some failure modes still if
                 # someone is doing a deep mixing of the same attribute
                 # name as plain Python attribute vs. dataclass field.
 
                 ret = local_datacls_fields.get(key, absent)
+                if _is_declarative_props(ret):
+                    ret = ret.fget
 
                 if ret is obj:
                     return False
@@ -414,9 +419,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                 for field in util.local_dataclass_fields(cls):
                     if sa_dataclass_metadata_key in field.metadata:
                         field_names.add(field.name)
-                        yield field.name, field.metadata[
-                            sa_dataclass_metadata_key
-                        ]
+                        yield field.name, _as_dc_declaredattr(
+                            field.metadata, sa_dataclass_metadata_key
+                        )
                 for name, obj in vars(cls).items():
                     if name not in field_names:
                         yield name, obj
@@ -507,7 +512,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                             "Mapper properties (i.e. deferred,"
                             "column_property(), relationship(), etc.) must "
                             "be declared as @declared_attr callables "
-                            "on declarative mixin classes."
+                            "on declarative mixin classes.  For dataclass "
+                            "field() objects, use a lambda:"
                         )
                     elif _is_declarative_props(obj):
                         if obj._cascading:
@@ -530,8 +536,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                             ] = ret = obj.__get__(obj, cls)
                             setattr(cls, name, ret)
                         else:
-                            # access attribute using normal class access
-                            ret = getattr(cls, name)
+                            if obj._is_dataclass:
+                                ret = obj.fget()
+                            else:
+
+                                # access attribute using normal class access
+                                ret = getattr(cls, name)
 
                             # correct for proxies created from hybrid_property
                             # or similar.  note there is no known case that
@@ -567,6 +577,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # assert that the dataclass-enabled resolver agrees
                     # with what we are seeing
                     assert not attribute_is_overridden(name, obj)
+
+                    if _is_declarative_props(obj):
+                        obj = obj.fget()
+
                     dict_[name] = obj
 
         if inherited_table_args and not tablename:
@@ -604,7 +618,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                     raise exc.InvalidRequestError(
                         "Columns with foreign keys to other columns "
                         "must be declared as @declared_attr callables "
-                        "on declarative mixin classes. "
+                        "on declarative mixin classes.  For dataclass "
+                        "field() objects, use a lambda:."
                     )
                 elif name not in dict_ and not (
                     "__table__" in dict_
@@ -957,6 +972,21 @@ class _ClassScanMapperConfig(_MapperConfig):
         )
 
 
+@util.preload_module("sqlalchemy.orm.decl_api")
+def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key):
+    # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr.
+    # we can't write it because field.metadata is immutable :( so we have
+    # to go through extra trouble to compare these
+    decl_api = util.preloaded.orm_decl_api
+    obj = field_metadata[sa_dataclass_metadata_key]
+    if callable(obj) and not isinstance(obj, decl_api.declared_attr):
+        return decl_api.declared_attr(obj, _is_dataclass=True)
+    elif isinstance(obj, decl_api.declared_attr):
+        obj._is_dataclass = True
+        return obj
+    return obj
+
+
 class _DeferredMapperConfig(_ClassScanMapperConfig):
     _configs = util.OrderedDict()
 
index ef1c12050edf4efe78ab8677eff7988ccc268948..56091505c3ff40b11696474149bc92b977573dda 100644 (file)
@@ -3,6 +3,7 @@ from typing import Optional
 
 from sqlalchemy import Boolean
 from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -521,6 +522,243 @@ class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
         eq_(dataclasses.astuple(widget), (None, "Bar", True))
 
 
+class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest):
+    __requires__ = ("dataclasses",)
+
+    run_setup_classes = "each"
+    run_setup_mappers = "each"
+
+    @classmethod
+    def setup_classes(cls):
+        declarative = cls.DeclarativeBasic.registry.mapped
+
+        @dataclasses.dataclass
+        class WidgetDC:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            widget_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+
+            # fk on mixin
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": lambda: Column(
+                        Integer,
+                        ForeignKey("accounts.account_id"),
+                        nullable=False,
+                    )
+                },
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class Widget(WidgetDC):
+            __tablename__ = "widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            type = Column(String(30), nullable=False)
+
+            name: Optional[str] = dataclasses.field(
+                default=None,
+                metadata={"sa": Column(String(30), nullable=False)},
+            )
+            __mapper_args__ = dict(
+                polymorphic_on="type",
+                polymorphic_identity="normal",
+            )
+
+        @dataclasses.dataclass
+        class AccountDC:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            # relationship on mixin
+            widgets: List[Widget] = dataclasses.field(
+                default_factory=list,
+                metadata={"sa": lambda: relationship("Widget")},
+            )
+
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+            widget_count: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column("widget_count", Integer, nullable=False)
+                },
+            )
+
+        @declarative
+        class Account(AccountDC):
+            __tablename__ = "accounts"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        @declarative
+        @dataclasses.dataclass
+        class User:
+            __tablename__ = "user"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            user_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+
+            # fk w declared attr on mapped class
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": lambda: Column(
+                        Integer,
+                        ForeignKey("accounts.account_id"),
+                        nullable=False,
+                    )
+                },
+            )
+
+        cls.classes["Account"] = Account
+        cls.classes["Widget"] = Widget
+        cls.classes["User"] = User
+
+    def test_setup(self):
+        Account, Widget, User = self.classes("Account", "Widget", "User")
+
+        assert "account_id" in Widget.__table__.c
+        assert list(Widget.__table__.c.account_id.foreign_keys)[0].references(
+            Account.__table__
+        )
+        assert inspect(Account).relationships.widgets.mapper is inspect(Widget)
+
+        assert "account_id" in User.__table__.c
+        assert list(User.__table__.c.account_id.foreign_keys)[0].references(
+            Account.__table__
+        )
+
+
+class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest):
+    __requires__ = ("dataclasses",)
+
+    @classmethod
+    def setup_classes(cls):
+        declarative = cls.DeclarativeBasic.registry.mapped
+
+        @dataclasses.dataclass
+        class WidgetDC:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            widget_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+
+            # fk on mixin
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": declared_attr(
+                        lambda: Column(
+                            Integer,
+                            ForeignKey("accounts.account_id"),
+                            nullable=False,
+                        )
+                    )
+                },
+            )
+
+        @declarative
+        @dataclasses.dataclass
+        class Widget(WidgetDC):
+            __tablename__ = "widgets"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            type = Column(String(30), nullable=False)
+
+            name: Optional[str] = dataclasses.field(
+                default=None,
+                metadata={"sa": Column(String(30), nullable=False)},
+            )
+            __mapper_args__ = dict(
+                polymorphic_on="type",
+                polymorphic_identity="normal",
+            )
+
+        @dataclasses.dataclass
+        class AccountDC:
+
+            __sa_dataclass_metadata_key__ = "sa"
+
+            # relationship on mixin
+            widgets: List[Widget] = dataclasses.field(
+                default_factory=list,
+                metadata={"sa": declared_attr(lambda: relationship("Widget"))},
+            )
+
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+            widget_count: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": Column("widget_count", Integer, nullable=False)
+                },
+            )
+
+        @declarative
+        class Account(AccountDC):
+            __tablename__ = "accounts"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            def __post_init__(self):
+                self.widget_count = len(self.widgets)
+
+            def add_widget(self, widget: Widget):
+                self.widgets.append(widget)
+                self.widget_count += 1
+
+        @declarative
+        @dataclasses.dataclass
+        class User:
+            __tablename__ = "user"
+            __sa_dataclass_metadata_key__ = "sa"
+
+            user_id: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, primary_key=True)},
+            )
+
+            # fk w declared attr on mapped class
+            account_id: int = dataclasses.field(
+                init=False,
+                metadata={
+                    "sa": declared_attr(
+                        lambda: Column(
+                            Integer,
+                            ForeignKey("accounts.account_id"),
+                            nullable=False,
+                        )
+                    )
+                },
+            )
+
+        cls.classes["Account"] = Account
+        cls.classes["Widget"] = Widget
+        cls.classes["User"] = User
+
+
 class PropagationFromMixinTest(fixtures.TestBase):
     __requires__ = ("dataclasses",)