]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add new pattern for single inh column override
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 17 Nov 2022 17:03:46 +0000 (12:03 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Nov 2022 23:04:08 +0000 (18:04 -0500)
Added a new parameter :paramref:`_orm.mapped_column.use_existing_column` to
accommodate the use case of a single-table inheritance mapping that uses
the pattern of more than one subclass indicating the same column to take
place on the superclass. This pattern was previously possible by using
:func:`_orm.declared_attr` in conjunction with locating the existing column
in the ``.__table__`` of the superclass, however is now updated to work
with :func:`_orm.mapped_column` as well as with pep-484 typing, in a
simple and succinct way.

Fixes: #8822
Change-Id: I2296a4a775da976c642c86567852cdc792610eaf

doc/build/changelog/unreleased_20/8822.rst [new file with mode: 0644]
doc/build/orm/inheritance.rst
lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
test/ext/declarative/test_reflection.py
test/orm/declarative/test_inheritance.py

diff --git a/doc/build/changelog/unreleased_20/8822.rst b/doc/build/changelog/unreleased_20/8822.rst
new file mode 100644 (file)
index 0000000..c3f062a
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+    :tags: feature, orm
+    :tickets: 8822
+
+    Added a new parameter :paramref:`_orm.mapped_column.use_existing_column` to
+    accommodate the use case of a single-table inheritance mapping that uses
+    the pattern of more than one subclass indicating the same column to take
+    place on the superclass. This pattern was previously possible by using
+    :func:`_orm.declared_attr` in conjunction with locating the existing column
+    in the ``.__table__`` of the superclass, however is now updated to work
+    with :func:`_orm.mapped_column` as well as with pep-484 typing, in a
+    simple and succinct way.
+
+    .. seealso::
+
+       :ref:`orm_inheritance_column_conflicts`
+
+
+
index 2552b643728ecd0e4780c77c98b899cf9130943f..7d7213db719362f161f8f4f31e835c9194225bb6 100644 (file)
@@ -325,8 +325,8 @@ their own.
 
 .. _orm_inheritance_column_conflicts:
 
-Resolving Column Conflicts
-+++++++++++++++++++++++++++
+Resolving Column Conflicts with ``use_existing_column``
++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
 Note in the previous section that the ``manager_name`` and ``engineer_info`` columns
 are "moved up" to be applied to ``Employee.__table__``, as a result of their
@@ -366,18 +366,20 @@ will result in an error:
 
 .. sourcecode:: text
 
-    sqlalchemy.exc.ArgumentError: Column 'start_date' on class
-    <class '__main__.Manager'> conflicts with existing
-    column 'employee.start_date'
+
+    sqlalchemy.exc.ArgumentError: Column 'start_date' on class Manager conflicts
+    with existing column 'employee.start_date'.  If using Declarative,
+    consider using the use_existing_column parameter of mapped_column() to
+    resolve conflicts.
 
 The above scenario presents an ambiguity to the Declarative mapping system that
-may be resolved by using
-:class:`.declared_attr` to define the :class:`_schema.Column` conditionally,
-taking care to return the **existing column** via the parent ``__table__``
-if it already exists::
+may be resolved by using the :paramref:`_orm.mapped_column.use_existing_column`
+parameter on :func:`_orm.mapped_column`, which instructs :func:`_orm.mapped_column`
+to look on the inheriting superclass present and use the column that's already
+mapped, if already present, else to map a new column::
+
 
     from sqlalchemy import DateTime
-    from sqlalchemy.orm import declared_attr
 
 
     class Employee(Base):
@@ -397,15 +399,7 @@ if it already exists::
             "polymorphic_identity": "engineer",
         }
 
-        @declared_attr
-        def start_date(cls) -> Mapped[datetime]:
-            "Start date column, if not present already."
-
-            # the DateTime type is required in the mapped_column
-            # at the moment when used inside of a @declared_attr
-            return Employee.__table__.c.get(
-                "start_date", mapped_column(DateTime)  # type: ignore
-            )
+        start_date: Mapped[datetime] = mapped_column(use_existing_column=True)
 
 
     class Manager(Employee):
@@ -413,20 +407,25 @@ if it already exists::
             "polymorphic_identity": "manager",
         }
 
-        @declared_attr
-        def start_date(cls) -> Mapped[datetime]:
-            "Start date column, if not present already."
-
-            # the DateTime type is required in the mapped_column
-            # at the moment when used inside of a @declared_attr
-            return Employee.__table__.c.get(
-                "start_date", mapped_column(DateTime)  # type: ignore
-            )
+        start_date: Mapped[datetime] = mapped_column(use_existing_column=True)
 
 Above, when ``Manager`` is mapped, the ``start_date`` column is
-already present on the ``Employee`` class; by returning the existing
-:class:`_schema.Column` object, the declarative system recognizes that this
-is the same column to be mapped to the two different subclasses separately.
+already present on the ``Employee`` class, having been provided by the
+``Engineer`` mapping already.   The :paramref:`_orm.mapped_column.use_existing_column`
+parameter indicates to :func:`_orm.mapped_column` that it should look for the
+requested :class:`_schema.Column` on the mapped :class:`.Table` for
+``Employee`` first, and if present, maintain that existing mapping.  If not
+present, :func:`_orm.mapped_column` will map the column normally, adding it
+as one of the columns in the :class:`.Table` referred towards by the
+``Employee`` superclass.
+
+
+.. versionadded:: 2.0.0b4 - Added :paramref:`_orm.mapped_column.use_existing_column`,
+   which provides a 2.0-compatible means of mapping a column on an inheriting
+   subclass conditionally.  The previous approach which combines
+   :class:`.declared_attr` with a lookup on the parent ``.__table__``
+   continues to function as well, but lacks :pep:`484` typing support.
+
 
 A similar concept can be used with mixin classes (see :ref:`orm_mixins_toplevel`)
 to define a particular series of columns and/or other mapped attributes
@@ -445,11 +444,7 @@ from a reusable mixin class::
 
 
     class HasStartDate:
-        @declared_attr
-        def start_date(cls) -> Mapped[datetime]:
-            return cls.__table__.c.get(
-                "start_date", mapped_column(DateTime)  # type: ignore
-            )
+        start_date: Mapped[datetime] = mapped_column(use_existing_column=True)
 
 
     class Engineer(HasStartDate, Employee):
index 2450d1e8364f284a4d183565100ca9f1ef58def8..2e8babd3d2a746c17d2913d81f12a0b77263a858 100644 (file)
@@ -112,6 +112,7 @@ def mapped_column(
     deferred: Union[_NoArg, bool] = _NoArg.NO_ARG,
     deferred_group: Optional[str] = None,
     deferred_raiseload: bool = False,
+    use_existing_column: bool = False,
     name: Optional[str] = None,
     type_: Optional[_TypeEngineArgument[Any]] = None,
     autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto",
@@ -208,6 +209,18 @@ def mapped_column(
 
         :ref:`orm_queryguide_deferred_raiseload`
 
+    :param use_existing_column: if True, will attempt to locate the given
+     column name on an inherited superclass (typically single inheriting
+     superclass), and if present, will not produce a new column, mapping
+     to the superclass column as though it were omitted from this class.
+     This is used for mixins that add new columns to an inherited superclass.
+
+     .. seealso::
+
+        :ref:`orm_inheritance_column_conflicts`
+
+     .. versionadded:: 2.0.0b4
+
     :param default: Passed directly to the
      :paramref:`_schema.Column.default` parameter if the
      :paramref:`_orm.mapped_column.insert_default` parameter is not present.
@@ -275,6 +288,7 @@ def mapped_column(
         primary_key=primary_key,
         server_default=server_default,
         server_onupdate=server_onupdate,
+        use_existing_column=use_existing_column,
         quote=quote,
         comment=comment,
         system=system,
index 1e716e687b2913fc6172d4a9063d986e71d3d227..797828377e0dd6772da641a5b0a28e909d1e5a32 100644 (file)
@@ -128,6 +128,28 @@ def _declared_mapping_info(
         return None
 
 
+def _is_supercls_for_inherits(cls: Type[Any]) -> bool:
+    """return True if this class will be used as a superclass to set in
+    'inherits'.
+
+    This includes deferred mapper configs that aren't mapped yet, however does
+    not include classes with _sa_decl_prepare_nocascade (e.g.
+    ``AbstractConcreteBase``); these concrete-only classes are not set up as
+    "inherits" until after mappers are configured using
+    mapper._set_concrete_base()
+
+    """
+    if _DeferredMapperConfig.has_cls(cls):
+        return not _get_immediate_cls_attr(
+            cls, "_sa_decl_prepare_nocascade", strict=True
+        )
+    # regular mapping
+    elif _is_mapped_class(cls):
+        return True
+    else:
+        return False
+
+
 def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]:
     if cls is object:
         return None
@@ -380,11 +402,8 @@ class _ImperativeMapperConfig(_MapperConfig):
                 c = _resolve_for_abstract_or_classical(base_)
                 if c is None:
                     continue
-                if _declared_mapping_info(
-                    c
-                ) is not None and not _get_immediate_cls_attr(
-                    c, "_sa_decl_prepare_nocascade", strict=True
-                ):
+
+                if _is_supercls_for_inherits(c) and c not in inherits_search:
                     inherits_search.append(c)
 
             if inherits_search:
@@ -430,6 +449,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         "allow_unmapped_annotations",
     )
 
+    is_deferred = False
     registry: _RegistryType
     clsdict_view: _ClassDict
     collected_annotations: Dict[str, _CollectedAnnotation]
@@ -532,13 +552,15 @@ class _ClassScanMapperConfig(_MapperConfig):
                 self.classname, self.cls, registry._class_registry
             )
 
+            self._setup_inheriting_mapper(mapper_kw)
+
             self._extract_mappable_attributes()
 
             self._extract_declared_columns()
 
             self._setup_table(table)
 
-            self._setup_inheritance(mapper_kw)
+            self._setup_inheriting_columns(mapper_kw)
 
             self._early_mapping(mapper_kw)
 
@@ -739,13 +761,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             # need to do this all the way up the hierarchy first
             # (see #8190)
 
-            class_mapped = (
-                base is not cls
-                and _declared_mapping_info(base) is not None
-                and not _get_immediate_cls_attr(
-                    base, "_sa_decl_prepare_nocascade", strict=True
-                )
-            )
+            class_mapped = base is not cls and _is_supercls_for_inherits(base)
 
             local_attributes_for_class = self._cls_attr_resolver(base)
 
@@ -1358,6 +1374,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                     if mapped_container is not None or annotation is None:
                         try:
                             value.declarative_scan(
+                                self,
                                 self.registry,
                                 cls,
                                 originating_module,
@@ -1558,11 +1575,8 @@ class _ClassScanMapperConfig(_MapperConfig):
         else:
             return manager.registry.metadata
 
-    def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None:
-        table = self.local_table
+    def _setup_inheriting_mapper(self, mapper_kw: _MapperKwArgs) -> None:
         cls = self.cls
-        table_args = self.table_args
-        declared_columns = self.declared_columns
 
         inherits = mapper_kw.get("inherits", None)
 
@@ -1574,13 +1588,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                 c = _resolve_for_abstract_or_classical(base_)
                 if c is None:
                     continue
-                if _declared_mapping_info(
-                    c
-                ) is not None and not _get_immediate_cls_attr(
-                    c, "_sa_decl_prepare_nocascade", strict=True
-                ):
-                    if c not in inherits_search:
-                        inherits_search.append(c)
+
+                if _is_supercls_for_inherits(c) and c not in inherits_search:
+                    inherits_search.append(c)
 
             if inherits_search:
                 if len(inherits_search) > 1:
@@ -1594,6 +1604,12 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         self.inherits = inherits
 
+    def _setup_inheriting_columns(self, mapper_kw: _MapperKwArgs) -> None:
+        table = self.local_table
+        cls = self.cls
+        table_args = self.table_args
+        declared_columns = self.declared_columns
+
         if (
             table is None
             and self.inherits is None
@@ -1636,9 +1652,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                         if inherited_table.c[col.name] is col:
                             continue
                         raise exc.ArgumentError(
-                            "Column '%s' on class %s conflicts with "
-                            "existing column '%s'"
-                            % (col, cls, inherited_table.c[col.name])
+                            f"Column '{col}' on class {cls.__name__} "
+                            f"conflicts with existing column "
+                            f"'{inherited_table.c[col.name]}'.  If using "
+                            f"Declarative, consider using the "
+                            "use_existing_column parameter of mapped_column() "
+                            "to resolve conflicts."
                         )
                     if col.primary_key:
                         raise exc.ArgumentError(
@@ -1695,14 +1714,15 @@ class _ClassScanMapperConfig(_MapperConfig):
             mapper_args["inherits"] = self.inherits
 
         if self.inherits and not mapper_args.get("concrete", False):
+            # note the superclass is expected to have a Mapper assigned and
+            # not be a deferred config, as this is called within map()
+            inherited_mapper = class_mapper(self.inherits, False)
+            inherited_table = inherited_mapper.local_table
+
             # single or joined inheritance
             # exclude any cols on the inherited table which are
             # not mapped on the parent class, to avoid
             # mapping columns specific to sibling/nephew classes
-            inherited_mapper = _declared_mapping_info(self.inherits)
-            assert isinstance(inherited_mapper, Mapper)
-            inherited_table = inherited_mapper.local_table
-
             if "exclude_properties" not in mapper_args:
                 mapper_args["exclude_properties"] = exclude_properties = {
                     c.key
@@ -1768,6 +1788,8 @@ def _as_dc_declaredattr(
 class _DeferredMapperConfig(_ClassScanMapperConfig):
     _cls: weakref.ref[Type[Any]]
 
+    is_deferred = True
+
     _configs: util.OrderedDict[
         weakref.ref[Type[Any]], _DeferredMapperConfig
     ] = util.OrderedDict()
index 55c7e9290b4809e2c48d3be09b82ade899ad57d3..56d6b2f6fe6e032c75feda772b7f349eb799d2a9 100644 (file)
@@ -64,6 +64,7 @@ if typing.TYPE_CHECKING:
     from .attributes import InstrumentedAttribute
     from .attributes import QueryableAttribute
     from .context import ORMCompileState
+    from .decl_base import _ClassScanMapperConfig
     from .mapper import Mapper
     from .properties import ColumnProperty
     from .properties import MappedColumn
@@ -332,6 +333,7 @@ class CompositeProperty(
     @util.preload_module("sqlalchemy.orm.properties")
     def declarative_scan(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
index 3d2f9708fc5c852f04f839e448e5edcb4f3a8fe9..18083241b224cf276bac952b70290f28c9a1bfda 100644 (file)
@@ -84,6 +84,7 @@ if typing.TYPE_CHECKING:
     from .context import ORMCompileState
     from .context import QueryContext
     from .decl_api import RegistryType
+    from .decl_base import _ClassScanMapperConfig
     from .loading import _PopulatorDict
     from .mapper import Mapper
     from .path_registry import AbstractEntityRegistry
@@ -157,6 +158,7 @@ class _IntrospectsAnnotations:
 
     def declarative_scan(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
index 1a5f0bd71d9724ac98912fcaf57b028ffc983369..e766fd06cd0295465d675d817a4e23157893dd8a 100644 (file)
@@ -28,6 +28,7 @@ from typing import TypeVar
 from . import attributes
 from . import strategy_options
 from .base import _DeclarativeMapped
+from .base import class_mapper
 from .descriptor_props import CompositeProperty
 from .descriptor_props import ConcreteInheritedProperty
 from .descriptor_props import SynonymProperty
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
     from ._typing import _ORMColumnExprArgument
     from ._typing import _RegistryType
     from .base import Mapped
+    from .decl_base import _ClassScanMapperConfig
     from .mapper import Mapper
     from .session import Session
     from .state import _InstallLoaderCallableProto
@@ -192,6 +194,7 @@ class ColumnProperty(
 
     def declarative_scan(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
@@ -531,6 +534,7 @@ class MappedColumn(
         "deferred_raiseload",
         "_attribute_options",
         "_has_dataclass_arguments",
+        "_use_existing_column",
     )
 
     deferred: bool
@@ -546,6 +550,8 @@ class MappedColumn(
             "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS
         )
 
+        self._use_existing_column = kw.pop("use_existing_column", False)
+
         self._has_dataclass_arguments = False
 
         if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
@@ -592,6 +598,7 @@ class MappedColumn(
         new._attribute_options = self._attribute_options
         new._has_insert_default = self._has_insert_default
         new._has_dataclass_arguments = self._has_dataclass_arguments
+        new._use_existing_column = self._use_existing_column
         util.set_creation_order(new)
         return new
 
@@ -635,6 +642,7 @@ class MappedColumn(
 
     def declarative_scan(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
@@ -645,6 +653,18 @@ class MappedColumn(
         is_dataclass_field: bool,
     ) -> None:
         column = self.column
+
+        if self._use_existing_column and decl_scan.inherits:
+            if decl_scan.is_deferred:
+                raise sa_exc.ArgumentError(
+                    "Can't use use_existing_column with deferred mappers"
+                )
+            supercls_mapper = class_mapper(decl_scan.inherits, False)
+
+            column = self.column = supercls_mapper.local_table.c.get(  # type: ignore # noqa: E501
+                key, column
+            )
+
         if column.key is None:
             column.key = key
         if column.name is None:
index 73d11e8800813fd938aacc986ec0183f1706bf34..4a9bcd711ea141c04dbbf526ce06cb7be1f03dc7 100644 (file)
@@ -101,6 +101,7 @@ if typing.TYPE_CHECKING:
     from .base import Mapped
     from .clsregistry import _class_resolver
     from .clsregistry import _ModNS
+    from .decl_base import _ClassScanMapperConfig
     from .dependency import DependencyProcessor
     from .mapper import Mapper
     from .query import Query
@@ -1723,6 +1724,7 @@ class RelationshipProperty(
 
     def declarative_scan(
         self,
+        decl_scan: _ClassScanMapperConfig,
         registry: _RegistryType,
         cls: Type[Any],
         originating_module: Optional[str],
index e143ad1277e7978cb923320d1b94938f6ca24793..53f518a27f7a0796160b50e56863fbc6275d99d1 100644 (file)
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
 from sqlalchemy import String
@@ -7,11 +10,14 @@ from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import decl_api as decl
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import exc as orm_exc
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.decl_base import _DeferredMapperConfig
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
@@ -33,6 +39,10 @@ class DeclarativeReflectionBase(fixtures.TablesTest):
     def teardown_test(self):
         clear_mappers()
 
+    @testing.fixture
+    def decl_base(self):
+        yield Base
+
 
 class DeferredReflectBase(DeclarativeReflectionBase):
     def teardown_test(self):
@@ -346,8 +356,8 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
             Column("bar_data", String(30)),
         )
 
-    def test_basic(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+    def test_basic(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -360,8 +370,8 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         DeferredReflection.prepare(testing.db)
         self._roundtrip()
 
-    def test_add_subclass_column(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+    def test_add_subclass_column(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -375,8 +385,40 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         DeferredReflection.prepare(testing.db)
         self._roundtrip()
 
-    def test_add_pk_column(self):
-        class Foo(DeferredReflection, fixtures.ComparableEntity, Base):
+    def test_add_subclass_mapped_column(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }
+
+        class Bar(Foo):
+            __mapper_args__ = {"polymorphic_identity": "bar"}
+            bar_data: Mapped[str]
+
+        DeferredReflection.prepare(testing.db)
+        self._roundtrip()
+
+    def test_subclass_mapped_column_no_existing(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            "Can't use use_existing_column with deferred mappers",
+        ):
+
+            class Bar(Foo):
+                __mapper_args__ = {"polymorphic_identity": "bar"}
+                bar_data: Mapped[str] = mapped_column(use_existing_column=True)
+
+    def test_add_pk_column(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
             __tablename__ = "foo"
             __mapper_args__ = {
                 "polymorphic_on": "type",
@@ -390,6 +432,21 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
         DeferredReflection.prepare(testing.db)
         self._roundtrip()
 
+    def test_add_pk_mapped_column(self, decl_base):
+        class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base):
+            __tablename__ = "foo"
+            __mapper_args__ = {
+                "polymorphic_on": "type",
+                "polymorphic_identity": "foo",
+            }
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        class Bar(Foo):
+            __mapper_args__ = {"polymorphic_identity": "bar"}
+
+        DeferredReflection.prepare(testing.db)
+        self._roundtrip()
+
 
 class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
     @classmethod
index 9829f423336c295ecff04b10177c428e71cb3cd1..f3506a3100adc885e715f65e364ccf6124af87ad 100644 (file)
@@ -1,5 +1,6 @@
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
+from sqlalchemy import Identity
 from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
@@ -9,7 +10,10 @@ from sqlalchemy.orm import close_all_sessions
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import Session
 from sqlalchemy.orm import with_polymorphic
 from sqlalchemy.orm.decl_api import registry
 from sqlalchemy.testing import assert_raises
@@ -805,11 +809,10 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
             [Person.__table__.c.name, Person.__table__.c.primary_language],
         )
 
-    @testing.skip_if(
-        lambda: testing.against("oracle"),
-        "Test has an empty insert in it at the moment",
-    )
-    def test_columns_single_inheritance_conflict_resolution(self):
+    @testing.variation("decl_type", ["legacy", "use_existing_column"])
+    def test_columns_single_inheritance_conflict_resolution(
+        self, connection, decl_base, decl_type
+    ):
         """Test that a declared_attr can return the existing column and it will
         be ignored.  this allows conditional columns to be added.
 
@@ -817,18 +820,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
         """
 
-        class Person(Base):
+        class Person(decl_base):
             __tablename__ = "person"
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, Identity(), primary_key=True)
 
         class Engineer(Person):
 
             """single table inheritance"""
 
-            @declared_attr
-            def target_id(cls):
-                return cls.__table__.c.get(
-                    "target_id", Column(Integer, ForeignKey("other.id"))
+            if decl_type.legacy:
+
+                @declared_attr
+                def target_id(cls):
+                    return cls.__table__.c.get(
+                        "target_id", Column(Integer, ForeignKey("other.id"))
+                    )
+
+            elif decl_type.use_existing_column:
+                target_id: Mapped[int] = mapped_column(
+                    ForeignKey("other.id"), use_existing_column=True
                 )
 
             @declared_attr
@@ -839,19 +849,26 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
             """single table inheritance"""
 
-            @declared_attr
-            def target_id(cls):
-                return cls.__table__.c.get(
-                    "target_id", Column(Integer, ForeignKey("other.id"))
+            if decl_type.legacy:
+
+                @declared_attr
+                def target_id(cls):
+                    return cls.__table__.c.get(
+                        "target_id", Column(Integer, ForeignKey("other.id"))
+                    )
+
+            elif decl_type.use_existing_column:
+                target_id: Mapped[int] = mapped_column(
+                    ForeignKey("other.id"), use_existing_column=True
                 )
 
             @declared_attr
             def target(cls):
                 return relationship("Other")
 
-        class Other(Base):
+        class Other(decl_base):
             __tablename__ = "other"
-            id = Column(Integer, primary_key=True)
+            id = Column(Integer, Identity(), primary_key=True)
 
         is_(
             Engineer.target_id.property.columns[0],
@@ -861,22 +878,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
             Manager.target_id.property.columns[0], Person.__table__.c.target_id
         )
         # do a brief round trip on this
-        Base.metadata.create_all(testing.db)
-        session = fixture_session()
-        o1, o2 = Other(), Other()
-        session.add_all(
-            [Engineer(target=o1), Manager(target=o2), Manager(target=o1)]
-        )
-        session.commit()
-        eq_(session.query(Engineer).first().target, o1)
+        decl_base.metadata.create_all(connection)
+        with Session(connection) as session:
+            o1, o2 = Other(), Other()
+            session.add_all(
+                [Engineer(target=o1), Manager(target=o2), Manager(target=o1)]
+            )
+            session.commit()
+            eq_(session.query(Engineer).first().target, o1)
 
-    def test_columns_single_inheritance_conflict_resolution_pk(self):
+    @testing.variation("decl_type", ["legacy", "use_existing_column"])
+    def test_columns_single_inheritance_conflict_resolution_pk(
+        self, decl_base, decl_type
+    ):
         """Test #2472 in terms of a primary key column.  This is
         #4352.
 
         """
 
-        class Person(Base):
+        class Person(decl_base):
             __tablename__ = "person"
             id = Column(Integer, primary_key=True)
 
@@ -886,20 +906,34 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
 
             """single table inheritance"""
 
-            @declared_attr
-            def target_id(cls):
-                return cls.__table__.c.get(
-                    "target_id", Column(Integer, primary_key=True)
+            if decl_type.legacy:
+
+                @declared_attr
+                def target_id(cls):
+                    return cls.__table__.c.get(
+                        "target_id", Column(Integer, primary_key=True)
+                    )
+
+            elif decl_type.use_existing_column:
+                target_id: Mapped[int] = mapped_column(
+                    primary_key=True, use_existing_column=True
                 )
 
         class Manager(Person):
 
             """single table inheritance"""
 
-            @declared_attr
-            def target_id(cls):
-                return cls.__table__.c.get(
-                    "target_id", Column(Integer, primary_key=True)
+            if decl_type.legacy:
+
+                @declared_attr
+                def target_id(cls):
+                    return cls.__table__.c.get(
+                        "target_id", Column(Integer, primary_key=True)
+                    )
+
+            elif decl_type.use_existing_column:
+                target_id: Mapped[int] = mapped_column(
+                    primary_key=True, use_existing_column=True
                 )
 
         is_(
@@ -910,20 +944,30 @@ class DeclarativeInheritanceTest(DeclarativeTestBase):
             Manager.target_id.property.columns[0], Person.__table__.c.target_id
         )
 
-    def test_columns_single_inheritance_cascading_resolution_pk(self):
+    @testing.variation("decl_type", ["legacy", "use_existing_column"])
+    def test_columns_single_inheritance_cascading_resolution_pk(
+        self, decl_type
+    ):
         """An additional test for #4352 in terms of the requested use case."""
 
         class TestBase(Base):
             __abstract__ = True
 
-            @declared_attr.cascading
-            def id(cls):
-                col_val = None
-                if TestBase not in cls.__bases__:
-                    col_val = cls.__table__.c.get("id")
-                if col_val is None:
-                    col_val = Column(Integer, primary_key=True)
-                return col_val
+            if decl_type.legacy:
+
+                @declared_attr.cascading
+                def id(cls):  # noqa: A001
+                    col_val = None
+                    if TestBase not in cls.__bases__:
+                        col_val = cls.__table__.c.get("id")
+                    if col_val is None:
+                        col_val = Column(Integer, primary_key=True)
+                    return col_val
+
+            elif decl_type.use_existing_column:
+                id: Mapped[int] = mapped_column(  # noqa: A001
+                    primary_key=True, use_existing_column=True
+                )
 
         class Person(TestBase):
             """single table base class"""