From: Mike Bayer Date: Mon, 25 Jan 2021 22:59:35 +0000 (-0500) Subject: Fill-out dataclass-related attr resolution X-Git-Tag: rel_1_4_0b2~24 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9205e9171cfd4b488be61228d8d53b0da1d49c19;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fill-out dataclass-related attr resolution Fixed issue where mixin attribute rules were not taking effect correctly for attributes pulled from dataclasses using the approach added in #5745. Fixes: #5876 Change-Id: I45099a42de1d9611791e72250fe0edc69bed684c --- diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index db6d274c86..db7dfebe4a 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -325,6 +325,94 @@ class _ClassScanMapperConfig(_MapperConfig): def before_configured(): self.cls.__declare_first__() + def _cls_attr_override_checker(self, cls): + """Produce a function that checks if a class has overridden an + attribute, taking SQLAlchemy-enabled dataclass fields into account. + + """ + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__", None + ) + + if sa_dataclass_metadata_key is None: + + def attribute_is_overridden(key, obj): + return getattr(cls, key) is not obj + + else: + + all_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + local_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.local_dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + + absent = object() + + def attribute_is_overridden(key, obj): + # this function likely has some failure modes still if + # someone is doing a deep mixing of the same attribute + # name as plain Python attribute vs. dataclass field. + + ret = local_datacls_fields.get(key, absent) + + if ret is obj: + return False + elif ret is not absent: + return True + + ret = getattr(cls, key, obj) + + if ret is obj: + return False + elif ret is not absent: + return True + + ret = all_datacls_fields.get(key, absent) + + if ret is obj: + return False + elif ret is not absent: + return True + + # can't find another attribute + return False + + return attribute_is_overridden + + def _cls_attr_resolver(self, cls): + """produce a function to iterate the "attributes" of a class, + adjusting for SQLAlchemy fields embedded in dataclass fields. + + """ + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__", None + ) + + if sa_dataclass_metadata_key is None: + + def local_attributes_for_class(): + for name, obj in vars(cls).items(): + yield name, obj + + else: + + def local_attributes_for_class(): + for name, obj in vars(cls).items(): + yield name, obj + for field in util.local_dataclass_fields(cls): + if sa_dataclass_metadata_key in field.metadata: + yield field.name, field.metadata[ + sa_dataclass_metadata_key + ] + + return local_attributes_for_class + def _scan_attributes(self): cls = self.cls dict_ = self.dict_ @@ -333,9 +421,9 @@ class _ClassScanMapperConfig(_MapperConfig): table_args = inherited_table_args = None tablename = None - for base in cls.__mro__: + attribute_is_overridden = self._cls_attr_override_checker(self.cls) - sa_dataclass_metadata_key = None + for base in cls.__mro__: class_mapped = ( base is not cls @@ -345,25 +433,14 @@ class _ClassScanMapperConfig(_MapperConfig): ) ) - if sa_dataclass_metadata_key is None: - sa_dataclass_metadata_key = _get_immediate_cls_attr( - base, "__sa_dataclass_metadata_key__", None - ) - - def attributes_for_class(cls): - for name, obj in vars(cls).items(): - yield name, obj - if sa_dataclass_metadata_key: - for field in util.dataclass_fields(cls): - if sa_dataclass_metadata_key in field.metadata: - yield field.name, field.metadata[ - sa_dataclass_metadata_key - ] + local_attributes_for_class = self._cls_attr_resolver(base) if not class_mapped and base is not cls: - self._produce_column_copies(attributes_for_class, base) + self._produce_column_copies( + local_attributes_for_class, attribute_is_overridden + ) - for name, obj in attributes_for_class(base): + for name, obj in local_attributes_for_class(): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -471,6 +548,15 @@ class _ClassScanMapperConfig(_MapperConfig): else: self._warn_for_decl_attributes(base, name, obj) elif name not in dict_ or dict_[name] is not obj: + # here, we are definitely looking at the target class + # and not a superclass. this is currently a + # dataclass-only path. if the name is only + # a dataclass field and isn't in local cls.__dict__, + # put the object there. + + # assert that the dataclass-enabled resolver agrees + # with what we are seeing + assert not attribute_is_overridden(name, obj) dict_[name] = obj if inherited_table_args and not tablename: @@ -489,14 +575,17 @@ class _ClassScanMapperConfig(_MapperConfig): % (key, cls) ) - def _produce_column_copies(self, attributes_for_class, base): + def _produce_column_copies( + self, attributes_for_class, attribute_is_overridden + ): cls = self.cls dict_ = self.dict_ column_copies = self.column_copies # copy mixin columns to the mapped class - for name, obj in attributes_for_class(base): + + for name, obj in attributes_for_class(): if isinstance(obj, Column): - if getattr(cls, name) is not obj: + if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the # superclass), skip diff --git a/lib/sqlalchemy/testing/fixtures.py b/lib/sqlalchemy/testing/fixtures.py index 62dd9040ee..4b76e6d88f 100644 --- a/lib/sqlalchemy/testing/fixtures.py +++ b/lib/sqlalchemy/testing/fixtures.py @@ -552,6 +552,7 @@ class DeclarativeMappedTest(MappedTest): metaclass=FindFixtureDeclarative, cls=DeclarativeBasic, ) + cls.DeclarativeBasic = _DeclBase # sets up cls.Basic which is helpful for things like composite diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 5f8788a6e8..2d86b8b633 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -66,6 +66,7 @@ from .compat import int_types # noqa from .compat import iterbytes # noqa from .compat import itertools_filter # noqa from .compat import itertools_filterfalse # noqa +from .compat import local_dataclass_fields # noqa from .compat import namedtuple # noqa from .compat import next # noqa from .compat import nullcontext # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 1eed2c3afe..5b7a3eb9fb 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -425,17 +425,37 @@ if py37: import dataclasses def dataclass_fields(cls): + """Return a sequence of all dataclasses.Field objects associated + with a class.""" + if dataclasses.is_dataclass(cls): return dataclasses.fields(cls) else: return [] + def local_dataclass_fields(cls): + """Return a sequence of all dataclasses.Field objects associated with + a class, excluding those that originate from a superclass.""" + + if dataclasses.is_dataclass(cls): + super_fields = set() + for sup in cls.__bases__: + super_fields.update(dataclass_fields(sup)) + return [ + f for f in dataclasses.fields(cls) if f not in super_fields + ] + else: + return [] + else: def dataclass_fields(cls): return [] + def local_dataclass_fields(cls): + return [] + def raise_from_cause(exception, exc_info=None): r"""legacy. use raise\_()""" diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index a4b5e4c83d..51c76c125a 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -6,12 +6,16 @@ from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy.orm import clear_mappers +from sqlalchemy.orm import declared_attr from sqlalchemy.orm import mapper from sqlalchemy.orm import registry as declarative_registry +from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -171,14 +175,14 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL): assert Widget("Foo") != Widget("Bar") assert Widget("Foo") != SpecialWidget("Foo") - def test_asdict_and_astuple(self): + def test_asdict_and_astuple_widget(self): Widget = self.classes.Widget - SpecialWidget = self.classes.SpecialWidget - widget = Widget("Foo") eq_(dataclasses.asdict(widget), {"name": "Foo"}) eq_(dataclasses.astuple(widget), ("Foo",)) + def test_asdict_and_astuple_special_widget(self): + SpecialWidget = self.classes.SpecialWidget widget = SpecialWidget("Bar", magic=True) eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True}) eq_(dataclasses.astuple(widget), ("Bar", True)) @@ -187,11 +191,11 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL): Account = self.classes.Account account = self.data_fixture() - with Session(testing.db) as session: + with fixture_session() as session: session.add(account) session.commit() - with Session(testing.db) as session: + with fixture_session() as session: a = session.query(Account).get(42) self.check_data_fixture(a) @@ -373,14 +377,229 @@ class FieldEmbeddedDeclarativeDataclassesTest( def define_tables(cls, metadata): pass - def test_asdict_and_astuple(self): + def test_asdict_and_astuple_widget(self): Widget = self.classes.Widget - SpecialWidget = self.classes.SpecialWidget widget = Widget("Foo") eq_(dataclasses.asdict(widget), {"name": "Foo"}) eq_(dataclasses.astuple(widget), ("Foo",)) + def test_asdict_and_astuple_special_widget(self): + SpecialWidget = self.classes.SpecialWidget widget = SpecialWidget("Bar", magic=True) eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True}) eq_(dataclasses.astuple(widget), ("Bar", True)) + + +class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest): + __requires__ = ("dataclasses",) + + @classmethod + def setup_classes(cls): + declarative = cls.DeclarativeBasic.registry.mapped + + @dataclasses.dataclass + class SurrogateWidgetPK: + + __sa_dataclass_metadata_key__ = "sa" + + widget_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + @declarative + @dataclasses.dataclass + class Widget(SurrogateWidgetPK): + __tablename__ = "widgets" + __sa_dataclass_metadata_key__ = "sa" + + account_id = Column( + Integer, + ForeignKey("accounts.account_id"), + nullable=False, + ) + type = Column(String(30), nullable=False) + + name: Optional[str] = dataclasses.field( + default=None, + metadata={"sa": Column(String(30), nullable=False)}, + ) + __mapper_args__ = dict( + polymorphic_on="type", + polymorphic_identity="normal", + ) + + @declarative + @dataclasses.dataclass + class SpecialWidget(Widget): + __sa_dataclass_metadata_key__ = "sa" + + magic: bool = dataclasses.field( + default=False, metadata={"sa": Column(Boolean)} + ) + + __mapper_args__ = dict( + polymorphic_identity="special", + ) + + @dataclasses.dataclass + class SurrogateAccountPK: + + __sa_dataclass_metadata_key__ = "sa" + + account_id = Column( + "we_dont_want_to_use_this", Integer, primary_key=True + ) + + @declarative + @dataclasses.dataclass + class Account(SurrogateAccountPK): + __tablename__ = "accounts" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + metadata={"sa": Column(Integer, primary_key=True)}, + ) + widgets: List[Widget] = dataclasses.field( + default_factory=list, metadata={"sa": relationship("Widget")} + ) + widget_count: int = dataclasses.field( + init=False, + metadata={ + "sa": Column("widget_count", Integer, nullable=False) + }, + ) + + def __post_init__(self): + self.widget_count = len(self.widgets) + + def add_widget(self, widget: Widget): + self.widgets.append(widget) + self.widget_count += 1 + + cls.classes.Account = Account + cls.classes.Widget = Widget + cls.classes.SpecialWidget = SpecialWidget + + def check_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + ( + id_, + name, + ) = dataclasses.fields(obj) + eq_(name.name, "name") + eq_(id_.name, "widget_id") + + def check_special_widget_dataclass(self, obj): + assert dataclasses.is_dataclass(obj) + id_, name, magic = dataclasses.fields(obj) + eq_(id_.name, "widget_id") + eq_(name.name, "name") + eq_(magic.name, "magic") + + def test_asdict_and_astuple_widget(self): + Widget = self.classes.Widget + + widget = Widget("Foo") + eq_(dataclasses.asdict(widget), {"name": "Foo", "widget_id": None}) + eq_( + dataclasses.astuple(widget), + ( + None, + "Foo", + ), + ) + + def test_asdict_and_astuple_special_widget(self): + SpecialWidget = self.classes.SpecialWidget + widget = SpecialWidget("Bar", magic=True) + eq_( + dataclasses.asdict(widget), + {"name": "Bar", "magic": True, "widget_id": None}, + ) + eq_(dataclasses.astuple(widget), (None, "Bar", True)) + + +class PropagationBlockTest(fixtures.TestBase): + __requires__ = ("dataclasses",) + + run_setup_classes = "each" + run_setup_mappers = "each" + + def test_propagate_w_plain_mixin_col(self, run_test): + @dataclasses.dataclass + class CommonMixin: + __sa_dataclass_metadata_key__ = "sa" + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {"mysql_engine": "InnoDB"} + timestamp = Column(Integer) + + run_test(CommonMixin) + + def test_propagate_w_field_mixin_col(self, run_test): + @dataclasses.dataclass + class CommonMixin: + __sa_dataclass_metadata_key__ = "sa" + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {"mysql_engine": "InnoDB"} + + timestamp: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, nullable=False)}, + ) + + run_test(CommonMixin) + + @testing.fixture() + def run_test(self): + def go(CommonMixin): + declarative = registry().mapped + + @declarative + @dataclasses.dataclass + class BaseType(CommonMixin): + + discriminator = Column("type", String(50)) + __mapper_args__ = dict(polymorphic_on=discriminator) + id = Column(Integer, primary_key=True) + value = Column(Integer()) + + @declarative + @dataclasses.dataclass + class Single(BaseType): + + __tablename__ = None + __mapper_args__ = dict(polymorphic_identity="type1") + + @declarative + @dataclasses.dataclass + class Joined(BaseType): + + __mapper_args__ = dict(polymorphic_identity="type2") + id = Column( + Integer, ForeignKey("basetype.id"), primary_key=True + ) + + eq_(BaseType.__table__.name, "basetype") + eq_( + list(BaseType.__table__.c.keys()), + ["timestamp", "type", "id", "value"], + ) + eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"}) + assert Single.__table__ is BaseType.__table__ + eq_(Joined.__table__.name, "joined") + eq_(list(Joined.__table__.c.keys()), ["id"]) + eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"}) + + yield go + + clear_mappers()