From: Mike Bayer Date: Wed, 26 Oct 2022 17:27:21 +0000 (-0400) Subject: ensure inherited mapper attrs not interpreted as plain dataclass fields X-Git-Tag: rel_2_0_0b3~24^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=73be84ae46473703dcf7b8d39e9666496fb07c8f;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure inherited mapper attrs not interpreted as plain dataclass fields Fixed issue in new dataclass mapping feature where a column declared on the decalrative base / abstract base / mixin would leak into the constructor for an inheriting subclass under some circumstances. Fixes: #8718 Change-Id: Ic519acf239e2f80541516f10995991cbbbed00bd --- diff --git a/doc/build/changelog/unreleased_20/8718.rst b/doc/build/changelog/unreleased_20/8718.rst new file mode 100644 index 0000000000..7aedaaa2cf --- /dev/null +++ b/doc/build/changelog/unreleased_20/8718.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, orm, declarative + :tickets: 8718 + + Fixed issue in new dataclass mapping feature where a column declared on the + decalrative base / abstract base / mixin would leak into the constructor + for an inheriting subclass under some circumstances. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 4e79ecc6fd..4e02e589b9 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -772,7 +772,6 @@ class _ClassScanMapperConfig(_MapperConfig): annotation, is_dataclass_field, ) in local_attributes_for_class(): - if re.match(r"^__.+__$", name): if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( @@ -825,6 +824,7 @@ class _ClassScanMapperConfig(_MapperConfig): "not applying to subclass %s." % (base.__name__, name, base, cls) ) + continue elif base is not cls: # we're a mixin, abstract base, or something that is @@ -990,10 +990,15 @@ class _ClassScanMapperConfig(_MapperConfig): _AttributeOptions._get_arguments_for_make_dataclass( key, anno, + mapped_container, self.collected_attributes.get(key, _NoArg.NO_ARG), ) - for key, anno in ( - (key, mapped_anno if mapped_anno else raw_anno) + for key, anno, mapped_container in ( + ( + key, + mapped_anno if mapped_anno else raw_anno, + mapped_container, + ) for key, ( raw_anno, mapped_container, @@ -1003,7 +1008,6 @@ class _ClassScanMapperConfig(_MapperConfig): ) in self.collected_annotations.items() ) ] - annotations = {} defaults = {} for item in field_list: @@ -1139,7 +1143,6 @@ class _ClassScanMapperConfig(_MapperConfig): # copy mixin columns to the mapped class for name, obj, annotation, is_dataclass in attributes_for_class(): - if ( not fixed_table and obj is None @@ -1154,14 +1157,16 @@ class _ClassScanMapperConfig(_MapperConfig): elif isinstance(obj, (Column, MappedColumn)): - obj = self._collect_annotation(name, annotation, True, obj) - if attribute_is_overridden(name, obj): # if column has been overridden # (like by the InstrumentedAttribute of the - # superclass), skip + # superclass), skip. don't collect the annotation + # either (issue #8718) continue - elif name not in dict_ and not ( + + obj = self._collect_annotation(name, annotation, True, obj) + + if name not in dict_ and not ( "__table__" in dict_ and (getattr(obj, "name", None) or name) in dict_["__table__"].c diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 9903c5f4a4..1747bfd9b2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -213,7 +213,11 @@ class _AttributeOptions(NamedTuple): @classmethod def _get_arguments_for_make_dataclass( - cls, key: str, annotation: Type[Any], elem: _T + cls, + key: str, + annotation: Type[Any], + mapped_container: Optional[Any], + elem: _T, ) -> Union[ Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]] ]: @@ -229,7 +233,21 @@ class _AttributeOptions(NamedTuple): elif elem is not _NoArg.NO_ARG: # why is typing not erroring on this? return (key, annotation, elem) + elif mapped_container is not None: + # it's Mapped[], but there's no "element", which means declarative + # did not actually do anything for this field. this shouldn't + # happen. + # previously, this would occur because _scan_attributes would + # skip a field that's on an already mapped superclass, but it + # would still include it in the annotations, leading + # to issue #8718 + + assert False, "Mapped[] received without a mapping declaration" + else: + # plain dataclass field, not mapped. Is only possible + # if __allow_unmapped__ is set up. I can see this mode causing + # problems... return (key, annotation) diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index ef62b7cb24..86c963ec68 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -406,6 +406,33 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): }, ) + def test_allow_unmapped_fields_wo_mapped_or_dc_w_inherits( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + __allow_unmapped__ = True + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + class B(A): + b_data: Mapped[str] = mapped_column(default="bd") + + b1 = B(data="data", ctrl_one="ctrl_one", some_field=5, b_data="x") + eq_( + dataclasses.asdict(b1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + "b_data": "x", + }, + ) + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): """We will be telling users "this is a dataclass that is also mapped". Therefore, they will want *any* kind of attribute to do what @@ -1186,6 +1213,138 @@ class DataclassArgsTest(fixtures.TestBase): eq_(prop._attribute_options, exp) +class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests for #8718""" + + __dialect__ = "default" + + @testing.fixture + def model(self): + def go(use_mixin, use_inherits, mad_setup): + + if use_mixin: + + if mad_setup == "dc, mad": + + class BaseEntity(DeclarativeBase, MappedAsDataclass): + pass + + elif mad_setup == "mad, dc": + + class BaseEntity(MappedAsDataclass, DeclarativeBase): + pass + + elif mad_setup == "subclass": + + class BaseEntity(DeclarativeBase): + pass + + class IdMixin: + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == "subclass": + + class A(IdMixin, MappedAsDataclass, BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(IdMixin, BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + if mad_setup == "dc, mad": + + class BaseEntity(DeclarativeBase, MappedAsDataclass): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "mad, dc": + + class BaseEntity(MappedAsDataclass, DeclarativeBase): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "subclass": + + class BaseEntity(DeclarativeBase): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == "subclass": + + class A(MappedAsDataclass, BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + if use_inherits: + + class B(A): + __mapper_args__ = { + "polymorphic_identity": "b", + } + b_data: Mapped[str] = mapped_column(String, init=False) + + return B + else: + return A + + yield go + + @testing.combinations("inherits", "plain", argnames="use_inherits") + @testing.combinations("mixin", "base", argnames="use_mixin") + @testing.combinations( + "mad, dc", "dc, mad", "subclass", argnames="mad_setup" + ) + def test_mapping(self, model, use_inherits, use_mixin, mad_setup): + target_cls = model( + use_inherits=use_inherits == "inherits", + use_mixin=use_mixin == "mixin", + mad_setup=mad_setup, + ) + + obj = target_cls() + assert "id" not in obj.__dict__ + + class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default"