From: Mike Bayer Date: Thu, 18 Mar 2021 19:07:03 +0000 (-0400) Subject: Adjust dataclass rules to account for field w/ default X-Git-Tag: rel_1_4_2~11^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ce2f28c37e0a2f2aa3b4a404ee190cdc00b8b918;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Adjust dataclass rules to account for field w/ default Fixed issue in new ORM dataclasses functionality where dataclass fields on an abstract base or mixin that contained column or other mapping constructs would not be mapped if they also included a "default" key within the dataclasses.field() object. Fixes: #6093 Change-Id: I628086ceb48ab1dd0702f239cd12be74074f58f1 --- diff --git a/doc/build/changelog/unreleased_14/6093.rst b/doc/build/changelog/unreleased_14/6093.rst new file mode 100644 index 0000000000..95e4af6f2c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6093.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm, dataclasses + :tickets: 6093 + + Fixed issue in new ORM dataclasses functionality where dataclass fields on + an abstract base or mixin that contained column or other mapping constructs + would not be mapped if they also included a "default" key within the + dataclasses.field() object. + diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index a21af192e0..0a73288fd6 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import instrumentation from . import clsregistry from . import exc as orm_exc from . import mapper as mapperlib +from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .base import _is_mapped_class from .base import InspectionAttr @@ -366,18 +367,24 @@ class _ClassScanMapperConfig(_MapperConfig): elif ret is not absent: return True + all_field = all_datacls_fields.get(key, absent) + ret = getattr(cls, key, obj) if ret is obj: return False - elif ret is not absent: - return True - ret = all_datacls_fields.get(key, absent) + # for dataclasses, this could be the + # 'default' of the field. so filter more specifically + # for an already-mapped InstrumentedAttribute + if ret is not absent and isinstance( + ret, InstrumentedAttribute + ): + return True - if ret is obj: + if all_field is obj: return False - elif ret is not absent: + elif all_field is not absent: return True # can't find another attribute @@ -401,15 +408,18 @@ class _ClassScanMapperConfig(_MapperConfig): yield name, obj else: + field_names = set() def local_attributes_for_class(): - for name, obj in vars(cls).items(): - yield name, obj for field in util.local_dataclass_fields(cls): if sa_dataclass_metadata_key in field.metadata: + field_names.add(field.name) yield field.name, field.metadata[ sa_dataclass_metadata_key ] + for name, obj in vars(cls).items(): + if name not in field_names: + yield name, obj return local_attributes_for_class diff --git a/test/orm/test_dataclasses_py3k.py b/test/orm/test_dataclasses_py3k.py index 51c76c125a..ef1c12050e 100644 --- a/test/orm/test_dataclasses_py3k.py +++ b/test/orm/test_dataclasses_py3k.py @@ -521,7 +521,7 @@ class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest): eq_(dataclasses.astuple(widget), (None, "Bar", True)) -class PropagationBlockTest(fixtures.TestBase): +class PropagationFromMixinTest(fixtures.TestBase): __requires__ = ("dataclasses",) run_setup_classes = "each" @@ -559,6 +559,25 @@ class PropagationBlockTest(fixtures.TestBase): run_test(CommonMixin) + def test_propagate_w_field_mixin_col_and_default(self, run_test): + @dataclasses.dataclass + class CommonMixin: + __sa_dataclass_metadata_key__ = "sa" + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {"mysql_engine": "InnoDB"} + + timestamp: int = dataclasses.field( + init=False, + default=12, + metadata={"sa": Column(Integer, nullable=False)}, + ) + + run_test(CommonMixin) + @testing.fixture() def run_test(self): def go(CommonMixin): @@ -603,3 +622,99 @@ class PropagationBlockTest(fixtures.TestBase): yield go clear_mappers() + + +class PropagationFromAbstractTest(fixtures.TestBase): + __requires__ = ("dataclasses",) + + run_setup_classes = "each" + run_setup_mappers = "each" + + def test_propagate_w_plain_mixin_col(self, run_test): + @dataclasses.dataclass + class BaseType: + __sa_dataclass_metadata_key__ = "sa" + + __table_args__ = {"mysql_engine": "InnoDB"} + + discriminator: str = Column("type", String(50)) + __mapper_args__ = dict(polymorphic_on=discriminator) + id: int = Column(Integer, primary_key=True) + value: int = Column(Integer()) + + timestamp: int = Column(Integer) + + run_test(BaseType) + + def test_propagate_w_field_mixin_col(self, run_test): + @dataclasses.dataclass + class BaseType: + __sa_dataclass_metadata_key__ = "sa" + + __table_args__ = {"mysql_engine": "InnoDB"} + + discriminator: str = Column("type", String(50)) + __mapper_args__ = dict(polymorphic_on=discriminator) + id: int = Column(Integer, primary_key=True) + value: int = Column(Integer()) + + timestamp: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, nullable=False)}, + ) + + run_test(BaseType) + + def test_propagate_w_field_mixin_col_and_default(self, run_test): + @dataclasses.dataclass + class BaseType: + __sa_dataclass_metadata_key__ = "sa" + + __table_args__ = {"mysql_engine": "InnoDB"} + + discriminator: str = Column("type", String(50)) + __mapper_args__ = dict(polymorphic_on=discriminator) + id: int = Column(Integer, primary_key=True) + value: int = Column(Integer()) + + timestamp: int = dataclasses.field( + init=False, + default=None, + metadata={"sa": Column(Integer, nullable=False)}, + ) + + run_test(BaseType) + + @testing.fixture() + def run_test(self): + def go(BaseType): + declarative = registry().mapped + + @declarative + @dataclasses.dataclass + class Single(BaseType): + + __tablename__ = "single" + __mapper_args__ = dict(polymorphic_identity="type1") + + @declarative + @dataclasses.dataclass + class Joined(Single): + __tablename__ = "joined" + __mapper_args__ = dict(polymorphic_identity="type2") + id = Column(Integer, ForeignKey("single.id"), primary_key=True) + + eq_(Single.__table__.name, "single") + eq_( + list(Single.__table__.c.keys()), + ["type", "id", "value", "timestamp"], + ) + eq_(Single.__table__.kwargs, {"mysql_engine": "InnoDB"}) + + eq_(Joined.__table__.name, "joined") + eq_(list(Joined.__table__.c.keys()), ["id"]) + eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"}) + + yield go + + clear_mappers()