]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure inherited mapper attrs not interpreted as plain dataclass fields
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Oct 2022 17:27:21 +0000 (13:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 26 Oct 2022 18:31:28 +0000 (14:31 -0400)
Fixed issue in new dataclass mapping feature where a column declared on the
decalrative base / abstract base / mixin would leak into the constructor
for an inheriting subclass under some circumstances.

Fixes: #8718
Change-Id: Ic519acf239e2f80541516f10995991cbbbed00bd

doc/build/changelog/unreleased_20/8718.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/interfaces.py
test/orm/declarative/test_dc_transforms.py

diff --git a/doc/build/changelog/unreleased_20/8718.rst b/doc/build/changelog/unreleased_20/8718.rst
new file mode 100644 (file)
index 0000000..7aedaaa
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm, declarative
+    :tickets: 8718
+
+    Fixed issue in new dataclass mapping feature where a column declared on the
+    decalrative base / abstract base / mixin would leak into the constructor
+    for an inheriting subclass under some circumstances.
index 4e79ecc6fd190aaa7144669d43019d815eac46d4..4e02e589b9cc2e7144e6e736d9f9f4ba487b6684 100644 (file)
@@ -772,7 +772,6 @@ class _ClassScanMapperConfig(_MapperConfig):
                 annotation,
                 is_dataclass_field,
             ) in local_attributes_for_class():
-
                 if re.match(r"^__.+__$", name):
                     if name == "__mapper_args__":
                         check_decl = _check_declared_props_nocascade(
@@ -825,6 +824,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                             "not applying to subclass %s."
                             % (base.__name__, name, base, cls)
                         )
+
                     continue
                 elif base is not cls:
                     # we're a mixin, abstract base, or something that is
@@ -990,10 +990,15 @@ class _ClassScanMapperConfig(_MapperConfig):
             _AttributeOptions._get_arguments_for_make_dataclass(
                 key,
                 anno,
+                mapped_container,
                 self.collected_attributes.get(key, _NoArg.NO_ARG),
             )
-            for key, anno in (
-                (key, mapped_anno if mapped_anno else raw_anno)
+            for key, anno, mapped_container in (
+                (
+                    key,
+                    mapped_anno if mapped_anno else raw_anno,
+                    mapped_container,
+                )
                 for key, (
                     raw_anno,
                     mapped_container,
@@ -1003,7 +1008,6 @@ class _ClassScanMapperConfig(_MapperConfig):
                 ) in self.collected_annotations.items()
             )
         ]
-
         annotations = {}
         defaults = {}
         for item in field_list:
@@ -1139,7 +1143,6 @@ class _ClassScanMapperConfig(_MapperConfig):
         # copy mixin columns to the mapped class
 
         for name, obj, annotation, is_dataclass in attributes_for_class():
-
             if (
                 not fixed_table
                 and obj is None
@@ -1154,14 +1157,16 @@ class _ClassScanMapperConfig(_MapperConfig):
 
             elif isinstance(obj, (Column, MappedColumn)):
 
-                obj = self._collect_annotation(name, annotation, True, obj)
-
                 if attribute_is_overridden(name, obj):
                     # if column has been overridden
                     # (like by the InstrumentedAttribute of the
-                    # superclass), skip
+                    # superclass), skip.  don't collect the annotation
+                    # either (issue #8718)
                     continue
-                elif name not in dict_ and not (
+
+                obj = self._collect_annotation(name, annotation, True, obj)
+
+                if name not in dict_ and not (
                     "__table__" in dict_
                     and (getattr(obj, "name", None) or name)
                     in dict_["__table__"].c
index 9903c5f4a42d6814759b4b862b112a04e37ecfef..1747bfd9b2a905afe70a4f8d224cfbfbdd62c547 100644 (file)
@@ -213,7 +213,11 @@ class _AttributeOptions(NamedTuple):
 
     @classmethod
     def _get_arguments_for_make_dataclass(
-        cls, key: str, annotation: Type[Any], elem: _T
+        cls,
+        key: str,
+        annotation: Type[Any],
+        mapped_container: Optional[Any],
+        elem: _T,
     ) -> Union[
         Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
     ]:
@@ -229,7 +233,21 @@ class _AttributeOptions(NamedTuple):
         elif elem is not _NoArg.NO_ARG:
             # why is typing not erroring on this?
             return (key, annotation, elem)
+        elif mapped_container is not None:
+            # it's Mapped[], but there's no "element", which means declarative
+            # did not actually do anything for this field.  this shouldn't
+            # happen.
+            # previously, this would occur because _scan_attributes would
+            # skip a field that's on an already mapped superclass, but it
+            # would still include it in the annotations, leading
+            # to issue #8718
+
+            assert False, "Mapped[] received without a mapping declaration"
+
         else:
+            # plain dataclass field, not mapped.  Is only possible
+            # if __allow_unmapped__ is set up.  I can see this mode causing
+            # problems...
             return (key, annotation)
 
 
index ef62b7cb245024492e1a7ac3cd666ff9361a13bf..86c963ec68866db4681b2007c681cb4a8a673434 100644 (file)
@@ -406,6 +406,33 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
             },
         )
 
+    def test_allow_unmapped_fields_wo_mapped_or_dc_w_inherits(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+            __allow_unmapped__ = True
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: str
+            ctrl_one: str = dataclasses.field()
+            some_field: int = dataclasses.field(default=5)
+
+        class B(A):
+            b_data: Mapped[str] = mapped_column(default="bd")
+
+        b1 = B(data="data", ctrl_one="ctrl_one", some_field=5, b_data="x")
+        eq_(
+            dataclasses.asdict(b1),
+            {
+                "ctrl_one": "ctrl_one",
+                "data": "data",
+                "id": None,
+                "some_field": 5,
+                "b_data": "x",
+            },
+        )
+
     def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]):
         """We will be telling users "this is a dataclass that is also
         mapped". Therefore, they will want *any* kind of attribute to do what
@@ -1186,6 +1213,138 @@ class DataclassArgsTest(fixtures.TestBase):
             eq_(prop._attribute_options, exp)
 
 
+class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    """tests for #8718"""
+
+    __dialect__ = "default"
+
+    @testing.fixture
+    def model(self):
+        def go(use_mixin, use_inherits, mad_setup):
+
+            if use_mixin:
+
+                if mad_setup == "dc, mad":
+
+                    class BaseEntity(DeclarativeBase, MappedAsDataclass):
+                        pass
+
+                elif mad_setup == "mad, dc":
+
+                    class BaseEntity(MappedAsDataclass, DeclarativeBase):
+                        pass
+
+                elif mad_setup == "subclass":
+
+                    class BaseEntity(DeclarativeBase):
+                        pass
+
+                class IdMixin:
+                    id: Mapped[int] = mapped_column(
+                        primary_key=True, init=False
+                    )
+
+                if mad_setup == "subclass":
+
+                    class A(IdMixin, MappedAsDataclass, BaseEntity):
+                        __mapper_args__ = {
+                            "polymorphic_on": "type",
+                            "polymorphic_identity": "a",
+                        }
+
+                        __tablename__ = "a"
+                        type: Mapped[str] = mapped_column(String, init=False)
+                        data: Mapped[str] = mapped_column(String, init=False)
+
+                else:
+
+                    class A(IdMixin, BaseEntity):
+                        __mapper_args__ = {
+                            "polymorphic_on": "type",
+                            "polymorphic_identity": "a",
+                        }
+
+                        __tablename__ = "a"
+                        type: Mapped[str] = mapped_column(String, init=False)
+                        data: Mapped[str] = mapped_column(String, init=False)
+
+            else:
+
+                if mad_setup == "dc, mad":
+
+                    class BaseEntity(DeclarativeBase, MappedAsDataclass):
+                        id: Mapped[int] = mapped_column(
+                            primary_key=True, init=False
+                        )
+
+                elif mad_setup == "mad, dc":
+
+                    class BaseEntity(MappedAsDataclass, DeclarativeBase):
+                        id: Mapped[int] = mapped_column(
+                            primary_key=True, init=False
+                        )
+
+                elif mad_setup == "subclass":
+
+                    class BaseEntity(DeclarativeBase):
+                        id: Mapped[int] = mapped_column(
+                            primary_key=True, init=False
+                        )
+
+                if mad_setup == "subclass":
+
+                    class A(MappedAsDataclass, BaseEntity):
+                        __mapper_args__ = {
+                            "polymorphic_on": "type",
+                            "polymorphic_identity": "a",
+                        }
+
+                        __tablename__ = "a"
+                        type: Mapped[str] = mapped_column(String, init=False)
+                        data: Mapped[str] = mapped_column(String, init=False)
+
+                else:
+
+                    class A(BaseEntity):
+                        __mapper_args__ = {
+                            "polymorphic_on": "type",
+                            "polymorphic_identity": "a",
+                        }
+
+                        __tablename__ = "a"
+                        type: Mapped[str] = mapped_column(String, init=False)
+                        data: Mapped[str] = mapped_column(String, init=False)
+
+            if use_inherits:
+
+                class B(A):
+                    __mapper_args__ = {
+                        "polymorphic_identity": "b",
+                    }
+                    b_data: Mapped[str] = mapped_column(String, init=False)
+
+                return B
+            else:
+                return A
+
+        yield go
+
+    @testing.combinations("inherits", "plain", argnames="use_inherits")
+    @testing.combinations("mixin", "base", argnames="use_mixin")
+    @testing.combinations(
+        "mad, dc", "dc, mad", "subclass", argnames="mad_setup"
+    )
+    def test_mapping(self, model, use_inherits, use_mixin, mad_setup):
+        target_cls = model(
+            use_inherits=use_inherits == "inherits",
+            use_mixin=use_mixin == "mixin",
+            mad_setup=mad_setup,
+        )
+
+        obj = target_cls()
+        assert "id" not in obj.__dict__
+
+
 class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"