]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
collect annotation earlier for mapped_column present
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Oct 2022 16:05:33 +0000 (12:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Oct 2022 16:09:32 +0000 (12:09 -0400)
Fixed issue with new dataclass mapping feature where arguments passed to
the dataclasses API could sometimes be mis-ordered when dealing with mixins
that override :func:`_orm.mapped_column` declarations, leading to
initializer problems.

the change made here is specific to the test case given which regards
mapped_column() most specifically.   cases that involve relationship()
etc. are not tested here, however mapped_column() is the only attribute
that's implicit without an instance given on the right side, and is also
most common for mixins.   not clear if there are more issues in this
area, however it appears that we need only adjust the order in which we
accommodate grabbing the annotations in order to affect how dataclasses
sees the class; that is, we have control over ``__annotations__`` here
so dont have to worry about ``cls.__dict__``.

Fixes: #8688
Change-Id: I808c86f23d73aa47cd910ae01c3e07093d469fdc

doc/build/changelog/unreleased_20/8688.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_dc_transforms.py

diff --git a/doc/build/changelog/unreleased_20/8688.rst b/doc/build/changelog/unreleased_20/8688.rst
new file mode 100644 (file)
index 0000000..7ae4d2b
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8688
+
+    Fixed issue with new dataclass mapping feature where arguments passed to
+    the dataclasses API could sometimes be mis-ordered when dealing with mixins
+    that override :func:`_orm.mapped_column` declarations, leading to
+    initializer problems.
index eed04025dc7f94beaa67d5486584b2ed84da8aa8..ef2c2f3c92d90c7fa52f30ae1ad83d2c009960c8 100644 (file)
@@ -420,7 +420,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
     registry: _RegistryType
     clsdict_view: _ClassDict
-    collected_annotations: Dict[str, Tuple[Any, Any, Any, bool]]
+    collected_annotations: Dict[str, Tuple[Any, Any, Any, bool, Any]]
     collected_attributes: Dict[str, Any]
     local_table: Optional[FromClause]
     persist_selectable: Optional[FromClause]
@@ -831,7 +831,6 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # acting like that for now.
 
                     if isinstance(obj, (Column, MappedColumn)):
-                        self._collect_annotation(name, annotation, True, obj)
                         # already copied columns to the mapped class.
                         continue
                     elif isinstance(obj, MapperProperty):
@@ -1000,6 +999,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                     mapped_container,
                     mapped_anno,
                     is_dc,
+                    attr_value,
                 ) in self.collected_annotations.items()
             )
         ]
@@ -1018,6 +1018,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         for k, v in defaults.items():
             setattr(self.cls, k, v)
+
         self.cls.__annotations__ = annotations
 
         self._assert_dc_arguments(dataclass_setup_arguments)
@@ -1056,6 +1057,10 @@ class _ClassScanMapperConfig(_MapperConfig):
         expect_mapped: Optional[bool],
         attr_value: Any,
     ) -> Any:
+
+        if name in self.collected_annotations:
+            return self.collected_annotations[name][4]
+
         if raw_annotation is None:
             return attr_value
 
@@ -1105,6 +1110,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             mapped_container,
             extracted_mapped_annotation,
             is_dataclass,
+            attr_value,
         )
         return attr_value
 
@@ -1133,6 +1139,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         # copy mixin columns to the mapped class
 
         for name, obj, annotation, is_dataclass in attributes_for_class():
+
             if (
                 not fixed_table
                 and obj is None
@@ -1146,6 +1153,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                 setattr(cls, name, obj)
 
             elif isinstance(obj, (Column, MappedColumn)):
+
+                obj = self._collect_annotation(name, annotation, True, obj)
+
                 if attribute_is_overridden(name, obj):
                     # if column has been overridden
                     # (like by the InstrumentedAttribute of the
@@ -1176,6 +1186,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
                     locally_collected_attributes[name] = copy_
                     setattr(cls, name, copy_)
+
         return locally_collected_attributes
 
     def _extract_mappable_attributes(self) -> None:
@@ -1260,8 +1271,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                         mapped_container,
                         extracted_mapped_annotation,
                         is_dataclass,
+                        attr_value,
                     ) = self.collected_annotations.get(
-                        k, (None, None, None, False)
+                        k, (None, None, None, False, None)
                     )
                     value.declarative_scan(
                         self.registry,
index bff9482ec5d542f87389fd9604103276909f6f7c..b467644bf58ea02bb7b0f869999f6c678f13cae0 100644 (file)
@@ -14,6 +14,7 @@ from unittest import mock
 
 from typing_extensions import Annotated
 
+from sqlalchemy import BigInteger
 from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
@@ -553,6 +554,31 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         eq_(fas.args, ["self", "id"])
         eq_(fas.kwonlyargs, ["data"])
 
+    def test_mapped_column_overrides(self, dc_decl_base):
+        """test #8688"""
+
+        class TriggeringMixin:
+            mixin_value: Mapped[int] = mapped_column(BigInteger)
+
+        class NonTriggeringMixin:
+            mixin_value: Mapped[int]
+
+        class Foo(dc_decl_base, TriggeringMixin):
+            __tablename__ = "foo"
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            foo_value: Mapped[float] = mapped_column(default=78)
+
+        class Bar(dc_decl_base, NonTriggeringMixin):
+            __tablename__ = "bar"
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            bar_value: Mapped[float] = mapped_column(default=78)
+
+        f1 = Foo(mixin_value=5)
+        eq_(f1.foo_value, 78)
+
+        b1 = Bar(mixin_value=5)
+        eq_(b1.bar_value, 78)
+
 
 class RelationshipDefaultFactoryTest(fixtures.TestBase):
     def test_list(self, dc_decl_base: Type[MappedAsDataclass]):