From: Mike Bayer Date: Fri, 23 Apr 2021 01:45:10 +0000 (-0400) Subject: implement declared_attr superclass assignment check for dataclasses X-Git-Tag: rel_1_4_12~34^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=8bd9d72271ec9a2fec5749c428ef5ad6e9dc2175;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement declared_attr superclass assignment check for dataclasses Adjusted the declarative scan for dataclasses so that the inheritance behavior of :func:`_orm.declared_attr` established on a mixin, when using the new form of having it inside of a ``dataclasses.field()`` construct and not actually a descriptor attribute on the class, correctly accommodates the case when the target class to be mapped is a subclass of an existing mapped class which has already mapped that :func:`_orm.declared_attr`, and therefore should not be re-applied to this class. Also, as changed in ed3f2c617239668d we now have an "is_dataclass" boolean set as we iterate through attrs so we can remove this from declared_attr. Fixes: #6346 Change-Id: Iec75bdefd3bff7d8a9a157c8dd744ac14ff15ea8 --- diff --git a/doc/build/changelog/unreleased_14/6346.rst b/doc/build/changelog/unreleased_14/6346.rst new file mode 100644 index 0000000000..4ca26b7492 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6346.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm, dataclasses + :tickets: 6346 + + Adjusted the declarative scan for dataclasses so that the inheritance + behavior of :func:`_orm.declared_attr` established on a mixin, when using + the new form of having it inside of a ``dataclasses.field()`` construct and + not actually a descriptor attribute on the class, correctly accommodates + the case when the target class to be mapped is a subclass of an existing + mapped class which has already mapped that :func:`_orm.declared_attr`, and + therefore should not be re-applied to this class. + diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index d9c464815b..4e2c3a8860 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -205,11 +205,10 @@ class declared_attr(interfaces._MappedAttribute, property): """ # noqa E501 - def __init__(self, fget, cascading=False, _is_dataclass=False): + def __init__(self, fget, cascading=False): super(declared_attr, self).__init__(fget) self.__doc__ = fget.__doc__ self._cascading = cascading - self._is_dataclass = _is_dataclass def __get__(desc, self, cls): # the declared_attr needs to make use of a cache that exists diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index b3444f26f4..f52827ad13 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -536,11 +536,24 @@ class _ClassScanMapperConfig(_MapperConfig): ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: - if obj._is_dataclass: - ret = obj.fget() - else: - + if is_dataclass: # access attribute using normal class access + # first, to see if it's been mapped on a + # superclass. note if the dataclasses.field() + # has "default", this value can be anything. + ret = getattr(cls, name, None) + + # so, if it's anything that's not ORM + # mapped, assume we should invoke the + # declared_attr + if not isinstance(ret, InspectionAttr): + ret = obj.fget() + else: + # access attribute using normal class access. + # if the declared attr already took place + # on a superclass that is mapped, then + # this is no longer a declared_attr, it will + # be the InstrumentedAttribute ret = getattr(cls, name) # correct for proxies created from hybrid_property @@ -988,11 +1001,9 @@ def _as_dc_declaredattr(field_metadata, sa_dataclass_metadata_key): decl_api = util.preloaded.orm_decl_api obj = field_metadata[sa_dataclass_metadata_key] if callable(obj) and not isinstance(obj, decl_api.declared_attr): - return decl_api.declared_attr(obj, _is_dataclass=True) - elif isinstance(obj, decl_api.declared_attr): - obj._is_dataclass = True + return decl_api.declared_attr(obj) + else: return obj - return obj class _DeferredMapperConfig(_ClassScanMapperConfig): diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index 56091505c3..2debc3ddf4 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -29,9 +29,6 @@ except ImportError: class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL): __requires__ = ("dataclasses",) - run_setup_classes = "each" - run_setup_mappers = "each" - @classmethod def define_tables(cls, metadata): Table( @@ -525,9 +522,6 @@ class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest): class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest): __requires__ = ("dataclasses",) - run_setup_classes = "each" - run_setup_mappers = "each" - @classmethod def setup_classes(cls): declarative = cls.DeclarativeBasic.registry.mapped @@ -554,6 +548,11 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest): }, ) + has_a_default: str = dataclasses.field( + default="some default", + metadata={"sa": lambda: Column(String(50))}, + ) + @declarative @dataclasses.dataclass class Widget(WidgetDC): @@ -566,11 +565,35 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest): default=None, metadata={"sa": Column(String(30), nullable=False)}, ) + __mapper_args__ = dict( polymorphic_on="type", polymorphic_identity="normal", ) + @declarative + @dataclasses.dataclass + class SpecialWidget(Widget): + __tablename__ = "special_widgets" + __sa_dataclass_metadata_key__ = "sa" + + special_widget_id: int = dataclasses.field( + init=False, + metadata={ + "sa": Column( + ForeignKey("widgets.widget_id"), primary_key=True + ) + }, + ) + + magic: bool = dataclasses.field( + default=False, metadata={"sa": Column(Boolean)} + ) + + __mapper_args__ = dict( + polymorphic_identity="special", + ) + @dataclasses.dataclass class AccountDC: @@ -631,9 +654,12 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest): cls.classes["Account"] = Account cls.classes["Widget"] = Widget cls.classes["User"] = User + cls.classes["SpecialWidget"] = SpecialWidget def test_setup(self): - Account, Widget, User = self.classes("Account", "Widget", "User") + Account, Widget, User, SpecialWidget = self.classes( + "Account", "Widget", "User", "SpecialWidget" + ) assert "account_id" in Widget.__table__.c assert list(Widget.__table__.c.account_id.foreign_keys)[0].references( @@ -641,11 +667,35 @@ class FieldEmbeddedMixinWLambdaTest(fixtures.DeclarativeMappedTest): ) assert inspect(Account).relationships.widgets.mapper is inspect(Widget) + assert "account_id" not in SpecialWidget.__table__.c + + assert "has_a_default" in Widget.__table__.c + assert "has_a_default" not in SpecialWidget.__table__.c + assert "account_id" in User.__table__.c assert list(User.__table__.c.account_id.foreign_keys)[0].references( Account.__table__ ) + def test_asdict_and_astuple_special_widget(self): + SpecialWidget = self.classes.SpecialWidget + widget = SpecialWidget(magic=True) + eq_( + dataclasses.asdict(widget), + { + "widget_id": None, + "account_id": None, + "has_a_default": "some default", + "name": None, + "special_widget_id": None, + "magic": True, + }, + ) + eq_( + dataclasses.astuple(widget), + (None, None, "some default", None, None, True), + ) + class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest): __requires__ = ("dataclasses",) @@ -678,6 +728,11 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest): }, ) + has_a_default: str = dataclasses.field( + default="some default", + metadata={"sa": declared_attr(lambda: Column(String(50)))}, + ) + @declarative @dataclasses.dataclass class Widget(WidgetDC): @@ -695,6 +750,29 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest): polymorphic_identity="normal", ) + @declarative + @dataclasses.dataclass + class SpecialWidget(Widget): + __tablename__ = "special_widgets" + __sa_dataclass_metadata_key__ = "sa" + + special_widget_id: int = dataclasses.field( + init=False, + metadata={ + "sa": Column( + ForeignKey("widgets.widget_id"), primary_key=True + ) + }, + ) + + magic: bool = dataclasses.field( + default=False, metadata={"sa": Column(Boolean)} + ) + + __mapper_args__ = dict( + polymorphic_identity="special", + ) + @dataclasses.dataclass class AccountDC: @@ -757,14 +835,12 @@ class FieldEmbeddedMixinWDeclaredAttrTest(FieldEmbeddedMixinWLambdaTest): cls.classes["Account"] = Account cls.classes["Widget"] = Widget cls.classes["User"] = User + cls.classes["SpecialWidget"] = SpecialWidget class PropagationFromMixinTest(fixtures.TestBase): __requires__ = ("dataclasses",) - run_setup_classes = "each" - run_setup_mappers = "each" - def test_propagate_w_plain_mixin_col(self, run_test): @dataclasses.dataclass class CommonMixin: @@ -865,9 +941,6 @@ class PropagationFromMixinTest(fixtures.TestBase): class PropagationFromAbstractTest(fixtures.TestBase): __requires__ = ("dataclasses",) - run_setup_classes = "each" - run_setup_mappers = "each" - def test_propagate_w_plain_mixin_col(self, run_test): @dataclasses.dataclass class BaseType: