]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix use_existing_column with Annotated mapped_column in polymorphic inheritance
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Aug 2025 18:05:49 +0000 (14:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Aug 2025 18:58:47 +0000 (14:58 -0400)
Fixed issue where :paramref:`_orm.mapped_column.use_existing_column`
parameter in :func:`_orm.mapped_column` would not work when the
:func:`_orm.mapped_column` is used inside of an ``Annotated`` type alias in
polymorphic inheritance scenarios. The parameter is now properly recognized
and processed during declarative mapping configuration.

Fixes: #12787
Change-Id: I0505df3f3714434e98052c4488f6b1b1d2b1f755

doc/build/changelog/unreleased_20/12787.rst [new file with mode: 0644]
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/properties.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/12787.rst b/doc/build/changelog/unreleased_20/12787.rst
new file mode 100644 (file)
index 0000000..44c36fe
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12787
+
+    Fixed issue where :paramref:`_orm.mapped_column.use_existing_column`
+    parameter in :func:`_orm.mapped_column` would not work when the
+    :func:`_orm.mapped_column` is used inside of an ``Annotated`` type alias in
+    polymorphic inheritance scenarios. The parameter is now properly recognized
+    and processed during declarative mapping configuration.
index 287d065b0b9972c9df701c7dd34eae03618b659a..62cb5afc7c0eb0831db423521b6e8e852bf375b1 100644 (file)
@@ -402,7 +402,9 @@ class CompositeProperty(
             self.composite_class = argument
 
         if is_dataclass(self.composite_class):
-            self._setup_for_dataclass(registry, cls, originating_module, key)
+            self._setup_for_dataclass(
+                decl_scan, registry, cls, originating_module, key
+            )
         else:
             for attr in self.attrs:
                 if (
@@ -446,6 +448,7 @@ class CompositeProperty(
     @util.preload_module("sqlalchemy.orm.decl_base")
     def _setup_for_dataclass(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
@@ -473,6 +476,7 @@ class CompositeProperty(
 
             if isinstance(attr, MappedColumn):
                 attr.declarative_scan_for_composite(
+                    decl_scan,
                     registry,
                     cls,
                     originating_module,
index 3afb6e140a0d5477c2f6c831aaabf56f828e238d..bc0c8fdda3238ab114d4c01db5c8ede21ba6eb17 100644 (file)
@@ -682,20 +682,12 @@ class MappedColumn(
         # Column will be merged into it in _init_column_for_annotation().
         return MappedColumn()
 
-    def declarative_scan(
+    def _adjust_for_existing_column(
         self,
         decl_scan: _ClassScanMapperConfig,
-        registry: _RegistryType,
-        cls: Type[Any],
-        originating_module: Optional[str],
         key: str,
-        mapped_container: Optional[Type[Mapped[Any]]],
-        annotation: Optional[_AnnotationScanType],
-        extracted_mapped_annotation: Optional[_AnnotationScanType],
-        is_dataclass_field: bool,
-    ) -> None:
-        column = self.column
-
+        given_column: Column[_T],
+    ) -> Column[_T]:
         if (
             self._use_existing_column
             and decl_scan.inherits
@@ -707,10 +699,31 @@ class MappedColumn(
                 )
             supercls_mapper = class_mapper(decl_scan.inherits, False)
 
-            colname = column.name if column.name is not None else key
-            column = self.column = supercls_mapper.local_table.c.get(  # type: ignore[assignment] # noqa: E501
-                colname, column
+            colname = (
+                given_column.name if given_column.name is not None else key
             )
+            given_column = supercls_mapper.local_table.c.get(  # type: ignore[assignment] # noqa: E501
+                colname, given_column
+            )
+        return given_column
+
+    def declarative_scan(
+        self,
+        decl_scan: _ClassScanMapperConfig,
+        registry: _RegistryType,
+        cls: Type[Any],
+        originating_module: Optional[str],
+        key: str,
+        mapped_container: Optional[Type[Mapped[Any]]],
+        annotation: Optional[_AnnotationScanType],
+        extracted_mapped_annotation: Optional[_AnnotationScanType],
+        is_dataclass_field: bool,
+    ) -> None:
+        column = self.column
+
+        column = self.column = self._adjust_for_existing_column(
+            decl_scan, key, self.column
+        )
 
         if column.key is None:
             column.key = key
@@ -727,6 +740,8 @@ class MappedColumn(
 
         self._init_column_for_annotation(
             cls,
+            decl_scan,
+            key,
             registry,
             extracted_mapped_annotation,
             originating_module,
@@ -735,6 +750,7 @@ class MappedColumn(
     @util.preload_module("sqlalchemy.orm.decl_base")
     def declarative_scan_for_composite(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
@@ -745,12 +761,14 @@ class MappedColumn(
         decl_base = util.preloaded.orm_decl_base
         decl_base._undefer_column_name(param_name, self.column)
         self._init_column_for_annotation(
-            cls, registry, param_annotation, originating_module
+            cls, decl_scan, key, registry, param_annotation, originating_module
         )
 
     def _init_column_for_annotation(
         self,
         cls: Type[Any],
+        decl_scan: _ClassScanMapperConfig,
+        key: str,
         registry: _RegistryType,
         argument: _AnnotationScanType,
         originating_module: Optional[str],
@@ -798,6 +816,10 @@ class MappedColumn(
 
         if use_args_from is not None:
 
+            self.column = use_args_from._adjust_for_existing_column(
+                decl_scan, key, self.column
+            )
+
             if (
                 self._has_insert_default
                 or self._attribute_options.dataclasses_default
index d55f9f80b56ee75043ae3c1b7380fd59606de12f..ac343c2315c4efcf26f7d480a20dac2d56876220 100644 (file)
@@ -2130,6 +2130,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
+    def test_use_existing_column_from_pep_593(self, decl_base):
+        """test #12787"""
+
+        global Label
+        Label = Annotated[
+            str, mapped_column(String(20), use_existing_column=True)
+        ]
+
+        class A(decl_base):
+            __tablename__ = "table_a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            discriminator: Mapped[int]
+
+            __mapper_args__ = {
+                "polymorphic_on": "discriminator",
+                "polymorphic_abstract": True,
+            }
+
+        class A_1(A):
+            label: Mapped[Label]
+
+            __mapper_args__ = {"polymorphic_identity": 1}
+
+        class A_2(A):
+            label: Mapped[Label]
+
+            __mapper_args__ = {"polymorphic_identity": 2}
+
+        is_(A_1.label.property.columns[0], A_2.label.property.columns[0])
+
+        eq_(A_1.label.property.columns[0].table, A.__table__)
+        eq_(A_2.label.property.columns[0].table, A.__table__)
+
     @testing.variation(
         "union",
         [
index c8c8fec9cd9171c00fb1b4d560a4624e0211e73e..72ba534f3593174254c0f7994e6717fcdb71c6da 100644 (file)
@@ -2121,6 +2121,40 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         else:
             is_(getattr(Element.__table__.c.data, paramname), override_value)
 
+    def test_use_existing_column_from_pep_593(self, decl_base):
+        """test #12787"""
+
+        # anno only: global Label
+        Label = Annotated[
+            str, mapped_column(String(20), use_existing_column=True)
+        ]
+
+        class A(decl_base):
+            __tablename__ = "table_a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            discriminator: Mapped[int]
+
+            __mapper_args__ = {
+                "polymorphic_on": "discriminator",
+                "polymorphic_abstract": True,
+            }
+
+        class A_1(A):
+            label: Mapped[Label]
+
+            __mapper_args__ = {"polymorphic_identity": 1}
+
+        class A_2(A):
+            label: Mapped[Label]
+
+            __mapper_args__ = {"polymorphic_identity": 2}
+
+        is_(A_1.label.property.columns[0], A_2.label.property.columns[0])
+
+        eq_(A_1.label.property.columns[0].table, A.__table__)
+        eq_(A_2.label.property.columns[0].table, A.__table__)
+
     @testing.variation(
         "union",
         [