From de68627dd1ba9c2dd44bb3d11be1a3945b285205 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 17 Nov 2022 12:03:46 -0500 Subject: [PATCH] add new pattern for single inh column override Added a new parameter :paramref:`_orm.mapped_column.use_existing_column` to accommodate the use case of a single-table inheritance mapping that uses the pattern of more than one subclass indicating the same column to take place on the superclass. This pattern was previously possible by using :func:`_orm.declared_attr` in conjunction with locating the existing column in the ``.__table__`` of the superclass, however is now updated to work with :func:`_orm.mapped_column` as well as with pep-484 typing, in a simple and succinct way. Fixes: #8822 Change-Id: I2296a4a775da976c642c86567852cdc792610eaf --- doc/build/changelog/unreleased_20/8822.rst | 19 +++ doc/build/orm/inheritance.rst | 67 +++++------ lib/sqlalchemy/orm/_orm_constructors.py | 14 +++ lib/sqlalchemy/orm/decl_base.py | 84 ++++++++----- lib/sqlalchemy/orm/descriptor_props.py | 2 + lib/sqlalchemy/orm/interfaces.py | 2 + lib/sqlalchemy/orm/properties.py | 20 ++++ lib/sqlalchemy/orm/relationships.py | 2 + test/ext/declarative/test_reflection.py | 69 ++++++++++- test/orm/declarative/test_inheritance.py | 132 ++++++++++++++------- 10 files changed, 294 insertions(+), 117 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8822.rst diff --git a/doc/build/changelog/unreleased_20/8822.rst b/doc/build/changelog/unreleased_20/8822.rst new file mode 100644 index 0000000000..c3f062ac94 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8822.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: feature, orm + :tickets: 8822 + + Added a new parameter :paramref:`_orm.mapped_column.use_existing_column` to + accommodate the use case of a single-table inheritance mapping that uses + the pattern of more than one subclass indicating the same column to take + place on the superclass. This pattern was previously possible by using + :func:`_orm.declared_attr` in conjunction with locating the existing column + in the ``.__table__`` of the superclass, however is now updated to work + with :func:`_orm.mapped_column` as well as with pep-484 typing, in a + simple and succinct way. + + .. seealso:: + + :ref:`orm_inheritance_column_conflicts` + + + diff --git a/doc/build/orm/inheritance.rst b/doc/build/orm/inheritance.rst index 2552b64372..7d7213db71 100644 --- a/doc/build/orm/inheritance.rst +++ b/doc/build/orm/inheritance.rst @@ -325,8 +325,8 @@ their own. .. _orm_inheritance_column_conflicts: -Resolving Column Conflicts -+++++++++++++++++++++++++++ +Resolving Column Conflicts with ``use_existing_column`` ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ Note in the previous section that the ``manager_name`` and ``engineer_info`` columns are "moved up" to be applied to ``Employee.__table__``, as a result of their @@ -366,18 +366,20 @@ will result in an error: .. sourcecode:: text - sqlalchemy.exc.ArgumentError: Column 'start_date' on class - conflicts with existing - column 'employee.start_date' + + sqlalchemy.exc.ArgumentError: Column 'start_date' on class Manager conflicts + with existing column 'employee.start_date'. If using Declarative, + consider using the use_existing_column parameter of mapped_column() to + resolve conflicts. The above scenario presents an ambiguity to the Declarative mapping system that -may be resolved by using -:class:`.declared_attr` to define the :class:`_schema.Column` conditionally, -taking care to return the **existing column** via the parent ``__table__`` -if it already exists:: +may be resolved by using the :paramref:`_orm.mapped_column.use_existing_column` +parameter on :func:`_orm.mapped_column`, which instructs :func:`_orm.mapped_column` +to look on the inheriting superclass present and use the column that's already +mapped, if already present, else to map a new column:: + from sqlalchemy import DateTime - from sqlalchemy.orm import declared_attr class Employee(Base): @@ -397,15 +399,7 @@ if it already exists:: "polymorphic_identity": "engineer", } - @declared_attr - def start_date(cls) -> Mapped[datetime]: - "Start date column, if not present already." - - # the DateTime type is required in the mapped_column - # at the moment when used inside of a @declared_attr - return Employee.__table__.c.get( - "start_date", mapped_column(DateTime) # type: ignore - ) + start_date: Mapped[datetime] = mapped_column(use_existing_column=True) class Manager(Employee): @@ -413,20 +407,25 @@ if it already exists:: "polymorphic_identity": "manager", } - @declared_attr - def start_date(cls) -> Mapped[datetime]: - "Start date column, if not present already." - - # the DateTime type is required in the mapped_column - # at the moment when used inside of a @declared_attr - return Employee.__table__.c.get( - "start_date", mapped_column(DateTime) # type: ignore - ) + start_date: Mapped[datetime] = mapped_column(use_existing_column=True) Above, when ``Manager`` is mapped, the ``start_date`` column is -already present on the ``Employee`` class; by returning the existing -:class:`_schema.Column` object, the declarative system recognizes that this -is the same column to be mapped to the two different subclasses separately. +already present on the ``Employee`` class, having been provided by the +``Engineer`` mapping already. The :paramref:`_orm.mapped_column.use_existing_column` +parameter indicates to :func:`_orm.mapped_column` that it should look for the +requested :class:`_schema.Column` on the mapped :class:`.Table` for +``Employee`` first, and if present, maintain that existing mapping. If not +present, :func:`_orm.mapped_column` will map the column normally, adding it +as one of the columns in the :class:`.Table` referred towards by the +``Employee`` superclass. + + +.. versionadded:: 2.0.0b4 - Added :paramref:`_orm.mapped_column.use_existing_column`, + which provides a 2.0-compatible means of mapping a column on an inheriting + subclass conditionally. The previous approach which combines + :class:`.declared_attr` with a lookup on the parent ``.__table__`` + continues to function as well, but lacks :pep:`484` typing support. + A similar concept can be used with mixin classes (see :ref:`orm_mixins_toplevel`) to define a particular series of columns and/or other mapped attributes @@ -445,11 +444,7 @@ from a reusable mixin class:: class HasStartDate: - @declared_attr - def start_date(cls) -> Mapped[datetime]: - return cls.__table__.c.get( - "start_date", mapped_column(DateTime) # type: ignore - ) + start_date: Mapped[datetime] = mapped_column(use_existing_column=True) class Engineer(HasStartDate, Employee): diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 2450d1e836..2e8babd3d2 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -112,6 +112,7 @@ def mapped_column( deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, deferred_group: Optional[str] = None, deferred_raiseload: bool = False, + use_existing_column: bool = False, name: Optional[str] = None, type_: Optional[_TypeEngineArgument[Any]] = None, autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", @@ -208,6 +209,18 @@ def mapped_column( :ref:`orm_queryguide_deferred_raiseload` + :param use_existing_column: if True, will attempt to locate the given + column name on an inherited superclass (typically single inheriting + superclass), and if present, will not produce a new column, mapping + to the superclass column as though it were omitted from this class. + This is used for mixins that add new columns to an inherited superclass. + + .. seealso:: + + :ref:`orm_inheritance_column_conflicts` + + .. versionadded:: 2.0.0b4 + :param default: Passed directly to the :paramref:`_schema.Column.default` parameter if the :paramref:`_orm.mapped_column.insert_default` parameter is not present. @@ -275,6 +288,7 @@ def mapped_column( primary_key=primary_key, server_default=server_default, server_onupdate=server_onupdate, + use_existing_column=use_existing_column, quote=quote, comment=comment, system=system, diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 1e716e687b..797828377e 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -128,6 +128,28 @@ def _declared_mapping_info( return None +def _is_supercls_for_inherits(cls: Type[Any]) -> bool: + """return True if this class will be used as a superclass to set in + 'inherits'. + + This includes deferred mapper configs that aren't mapped yet, however does + not include classes with _sa_decl_prepare_nocascade (e.g. + ``AbstractConcreteBase``); these concrete-only classes are not set up as + "inherits" until after mappers are configured using + mapper._set_concrete_base() + + """ + if _DeferredMapperConfig.has_cls(cls): + return not _get_immediate_cls_attr( + cls, "_sa_decl_prepare_nocascade", strict=True + ) + # regular mapping + elif _is_mapped_class(cls): + return True + else: + return False + + def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]: if cls is object: return None @@ -380,11 +402,8 @@ class _ImperativeMapperConfig(_MapperConfig): c = _resolve_for_abstract_or_classical(base_) if c is None: continue - if _declared_mapping_info( - c - ) is not None and not _get_immediate_cls_attr( - c, "_sa_decl_prepare_nocascade", strict=True - ): + + if _is_supercls_for_inherits(c) and c not in inherits_search: inherits_search.append(c) if inherits_search: @@ -430,6 +449,7 @@ class _ClassScanMapperConfig(_MapperConfig): "allow_unmapped_annotations", ) + is_deferred = False registry: _RegistryType clsdict_view: _ClassDict collected_annotations: Dict[str, _CollectedAnnotation] @@ -532,13 +552,15 @@ class _ClassScanMapperConfig(_MapperConfig): self.classname, self.cls, registry._class_registry ) + self._setup_inheriting_mapper(mapper_kw) + self._extract_mappable_attributes() self._extract_declared_columns() self._setup_table(table) - self._setup_inheritance(mapper_kw) + self._setup_inheriting_columns(mapper_kw) self._early_mapping(mapper_kw) @@ -739,13 +761,7 @@ class _ClassScanMapperConfig(_MapperConfig): # need to do this all the way up the hierarchy first # (see #8190) - class_mapped = ( - base is not cls - and _declared_mapping_info(base) is not None - and not _get_immediate_cls_attr( - base, "_sa_decl_prepare_nocascade", strict=True - ) - ) + class_mapped = base is not cls and _is_supercls_for_inherits(base) local_attributes_for_class = self._cls_attr_resolver(base) @@ -1358,6 +1374,7 @@ class _ClassScanMapperConfig(_MapperConfig): if mapped_container is not None or annotation is None: try: value.declarative_scan( + self, self.registry, cls, originating_module, @@ -1558,11 +1575,8 @@ class _ClassScanMapperConfig(_MapperConfig): else: return manager.registry.metadata - def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: - table = self.local_table + def _setup_inheriting_mapper(self, mapper_kw: _MapperKwArgs) -> None: cls = self.cls - table_args = self.table_args - declared_columns = self.declared_columns inherits = mapper_kw.get("inherits", None) @@ -1574,13 +1588,9 @@ class _ClassScanMapperConfig(_MapperConfig): c = _resolve_for_abstract_or_classical(base_) if c is None: continue - if _declared_mapping_info( - c - ) is not None and not _get_immediate_cls_attr( - c, "_sa_decl_prepare_nocascade", strict=True - ): - if c not in inherits_search: - inherits_search.append(c) + + if _is_supercls_for_inherits(c) and c not in inherits_search: + inherits_search.append(c) if inherits_search: if len(inherits_search) > 1: @@ -1594,6 +1604,12 @@ class _ClassScanMapperConfig(_MapperConfig): self.inherits = inherits + def _setup_inheriting_columns(self, mapper_kw: _MapperKwArgs) -> None: + table = self.local_table + cls = self.cls + table_args = self.table_args + declared_columns = self.declared_columns + if ( table is None and self.inherits is None @@ -1636,9 +1652,12 @@ class _ClassScanMapperConfig(_MapperConfig): if inherited_table.c[col.name] is col: continue raise exc.ArgumentError( - "Column '%s' on class %s conflicts with " - "existing column '%s'" - % (col, cls, inherited_table.c[col.name]) + f"Column '{col}' on class {cls.__name__} " + f"conflicts with existing column " + f"'{inherited_table.c[col.name]}'. If using " + f"Declarative, consider using the " + "use_existing_column parameter of mapped_column() " + "to resolve conflicts." ) if col.primary_key: raise exc.ArgumentError( @@ -1695,14 +1714,15 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args["inherits"] = self.inherits if self.inherits and not mapper_args.get("concrete", False): + # note the superclass is expected to have a Mapper assigned and + # not be a deferred config, as this is called within map() + inherited_mapper = class_mapper(self.inherits, False) + inherited_table = inherited_mapper.local_table + # single or joined inheritance # exclude any cols on the inherited table which are # not mapped on the parent class, to avoid # mapping columns specific to sibling/nephew classes - inherited_mapper = _declared_mapping_info(self.inherits) - assert isinstance(inherited_mapper, Mapper) - inherited_table = inherited_mapper.local_table - if "exclude_properties" not in mapper_args: mapper_args["exclude_properties"] = exclude_properties = { c.key @@ -1768,6 +1788,8 @@ def _as_dc_declaredattr( class _DeferredMapperConfig(_ClassScanMapperConfig): _cls: weakref.ref[Type[Any]] + is_deferred = True + _configs: util.OrderedDict[ weakref.ref[Type[Any]], _DeferredMapperConfig ] = util.OrderedDict() diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 55c7e9290b..56d6b2f6fe 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -64,6 +64,7 @@ if typing.TYPE_CHECKING: from .attributes import InstrumentedAttribute from .attributes import QueryableAttribute from .context import ORMCompileState + from .decl_base import _ClassScanMapperConfig from .mapper import Mapper from .properties import ColumnProperty from .properties import MappedColumn @@ -332,6 +333,7 @@ class CompositeProperty( @util.preload_module("sqlalchemy.orm.properties") def declarative_scan( self, + decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 3d2f9708fc..18083241b2 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -84,6 +84,7 @@ if typing.TYPE_CHECKING: from .context import ORMCompileState from .context import QueryContext from .decl_api import RegistryType + from .decl_base import _ClassScanMapperConfig from .loading import _PopulatorDict from .mapper import Mapper from .path_registry import AbstractEntityRegistry @@ -157,6 +158,7 @@ class _IntrospectsAnnotations: def declarative_scan( self, + decl_scan: _ClassScanMapperConfig, registry: RegistryType, cls: Type[Any], originating_module: Optional[str], diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 1a5f0bd71d..e766fd06cd 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -28,6 +28,7 @@ from typing import TypeVar from . import attributes from . import strategy_options from .base import _DeclarativeMapped +from .base import class_mapper from .descriptor_props import CompositeProperty from .descriptor_props import ConcreteInheritedProperty from .descriptor_props import SynonymProperty @@ -66,6 +67,7 @@ if TYPE_CHECKING: from ._typing import _ORMColumnExprArgument from ._typing import _RegistryType from .base import Mapped + from .decl_base import _ClassScanMapperConfig from .mapper import Mapper from .session import Session from .state import _InstallLoaderCallableProto @@ -192,6 +194,7 @@ class ColumnProperty( def declarative_scan( self, + decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], @@ -531,6 +534,7 @@ class MappedColumn( "deferred_raiseload", "_attribute_options", "_has_dataclass_arguments", + "_use_existing_column", ) deferred: bool @@ -546,6 +550,8 @@ class MappedColumn( "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS ) + self._use_existing_column = kw.pop("use_existing_column", False) + self._has_dataclass_arguments = False if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: @@ -592,6 +598,7 @@ class MappedColumn( new._attribute_options = self._attribute_options new._has_insert_default = self._has_insert_default new._has_dataclass_arguments = self._has_dataclass_arguments + new._use_existing_column = self._use_existing_column util.set_creation_order(new) return new @@ -635,6 +642,7 @@ class MappedColumn( def declarative_scan( self, + decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], @@ -645,6 +653,18 @@ class MappedColumn( is_dataclass_field: bool, ) -> None: column = self.column + + if self._use_existing_column and decl_scan.inherits: + if decl_scan.is_deferred: + raise sa_exc.ArgumentError( + "Can't use use_existing_column with deferred mappers" + ) + supercls_mapper = class_mapper(decl_scan.inherits, False) + + column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + key, column + ) + if column.key is None: column.key = key if column.name is None: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 73d11e8800..4a9bcd711e 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -101,6 +101,7 @@ if typing.TYPE_CHECKING: from .base import Mapped from .clsregistry import _class_resolver from .clsregistry import _ModNS + from .decl_base import _ClassScanMapperConfig from .dependency import DependencyProcessor from .mapper import Mapper from .query import Query @@ -1723,6 +1724,7 @@ class RelationshipProperty( def declarative_scan( self, + decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index e143ad1277..53f518a27f 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String @@ -7,11 +10,14 @@ from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import decl_api as decl from sqlalchemy.orm import declared_attr from sqlalchemy.orm import exc as orm_exc +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm.decl_base import _DeferredMapperConfig from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -33,6 +39,10 @@ class DeclarativeReflectionBase(fixtures.TablesTest): def teardown_test(self): clear_mappers() + @testing.fixture + def decl_base(self): + yield Base + class DeferredReflectBase(DeclarativeReflectionBase): def teardown_test(self): @@ -346,8 +356,8 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): Column("bar_data", String(30)), ) - def test_basic(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + def test_basic(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -360,8 +370,8 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): DeferredReflection.prepare(testing.db) self._roundtrip() - def test_add_subclass_column(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + def test_add_subclass_column(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -375,8 +385,40 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): DeferredReflection.prepare(testing.db) self._roundtrip() - def test_add_pk_column(self): - class Foo(DeferredReflection, fixtures.ComparableEntity, Base): + def test_add_subclass_mapped_column(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + __tablename__ = "foo" + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "foo", + } + + class Bar(Foo): + __mapper_args__ = {"polymorphic_identity": "bar"} + bar_data: Mapped[str] + + DeferredReflection.prepare(testing.db) + self._roundtrip() + + def test_subclass_mapped_column_no_existing(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + __tablename__ = "foo" + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "foo", + } + + with expect_raises_message( + exc.ArgumentError, + "Can't use use_existing_column with deferred mappers", + ): + + class Bar(Foo): + __mapper_args__ = {"polymorphic_identity": "bar"} + bar_data: Mapped[str] = mapped_column(use_existing_column=True) + + def test_add_pk_column(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): __tablename__ = "foo" __mapper_args__ = { "polymorphic_on": "type", @@ -390,6 +432,21 @@ class DeferredSingleInhReflectionTest(DeferredInhReflectBase): DeferredReflection.prepare(testing.db) self._roundtrip() + def test_add_pk_mapped_column(self, decl_base): + class Foo(DeferredReflection, fixtures.ComparableEntity, decl_base): + __tablename__ = "foo" + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "foo", + } + id: Mapped[int] = mapped_column(primary_key=True) + + class Bar(Foo): + __mapper_args__ = {"polymorphic_identity": "bar"} + + DeferredReflection.prepare(testing.db) + self._roundtrip() + class DeferredJoinedInhReflectionTest(DeferredInhReflectBase): @classmethod diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index 9829f42333..f3506a3100 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -1,5 +1,6 @@ import sqlalchemy as sa from sqlalchemy import ForeignKey +from sqlalchemy import Identity from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String @@ -9,7 +10,10 @@ from sqlalchemy.orm import close_all_sessions from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic from sqlalchemy.orm.decl_api import registry from sqlalchemy.testing import assert_raises @@ -805,11 +809,10 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): [Person.__table__.c.name, Person.__table__.c.primary_language], ) - @testing.skip_if( - lambda: testing.against("oracle"), - "Test has an empty insert in it at the moment", - ) - def test_columns_single_inheritance_conflict_resolution(self): + @testing.variation("decl_type", ["legacy", "use_existing_column"]) + def test_columns_single_inheritance_conflict_resolution( + self, connection, decl_base, decl_type + ): """Test that a declared_attr can return the existing column and it will be ignored. this allows conditional columns to be added. @@ -817,18 +820,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): """ - class Person(Base): + class Person(decl_base): __tablename__ = "person" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) class Engineer(Person): """single table inheritance""" - @declared_attr - def target_id(cls): - return cls.__table__.c.get( - "target_id", Column(Integer, ForeignKey("other.id")) + if decl_type.legacy: + + @declared_attr + def target_id(cls): + return cls.__table__.c.get( + "target_id", Column(Integer, ForeignKey("other.id")) + ) + + elif decl_type.use_existing_column: + target_id: Mapped[int] = mapped_column( + ForeignKey("other.id"), use_existing_column=True ) @declared_attr @@ -839,19 +849,26 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): """single table inheritance""" - @declared_attr - def target_id(cls): - return cls.__table__.c.get( - "target_id", Column(Integer, ForeignKey("other.id")) + if decl_type.legacy: + + @declared_attr + def target_id(cls): + return cls.__table__.c.get( + "target_id", Column(Integer, ForeignKey("other.id")) + ) + + elif decl_type.use_existing_column: + target_id: Mapped[int] = mapped_column( + ForeignKey("other.id"), use_existing_column=True ) @declared_attr def target(cls): return relationship("Other") - class Other(Base): + class Other(decl_base): __tablename__ = "other" - id = Column(Integer, primary_key=True) + id = Column(Integer, Identity(), primary_key=True) is_( Engineer.target_id.property.columns[0], @@ -861,22 +878,25 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Manager.target_id.property.columns[0], Person.__table__.c.target_id ) # do a brief round trip on this - Base.metadata.create_all(testing.db) - session = fixture_session() - o1, o2 = Other(), Other() - session.add_all( - [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] - ) - session.commit() - eq_(session.query(Engineer).first().target, o1) + decl_base.metadata.create_all(connection) + with Session(connection) as session: + o1, o2 = Other(), Other() + session.add_all( + [Engineer(target=o1), Manager(target=o2), Manager(target=o1)] + ) + session.commit() + eq_(session.query(Engineer).first().target, o1) - def test_columns_single_inheritance_conflict_resolution_pk(self): + @testing.variation("decl_type", ["legacy", "use_existing_column"]) + def test_columns_single_inheritance_conflict_resolution_pk( + self, decl_base, decl_type + ): """Test #2472 in terms of a primary key column. This is #4352. """ - class Person(Base): + class Person(decl_base): __tablename__ = "person" id = Column(Integer, primary_key=True) @@ -886,20 +906,34 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): """single table inheritance""" - @declared_attr - def target_id(cls): - return cls.__table__.c.get( - "target_id", Column(Integer, primary_key=True) + if decl_type.legacy: + + @declared_attr + def target_id(cls): + return cls.__table__.c.get( + "target_id", Column(Integer, primary_key=True) + ) + + elif decl_type.use_existing_column: + target_id: Mapped[int] = mapped_column( + primary_key=True, use_existing_column=True ) class Manager(Person): """single table inheritance""" - @declared_attr - def target_id(cls): - return cls.__table__.c.get( - "target_id", Column(Integer, primary_key=True) + if decl_type.legacy: + + @declared_attr + def target_id(cls): + return cls.__table__.c.get( + "target_id", Column(Integer, primary_key=True) + ) + + elif decl_type.use_existing_column: + target_id: Mapped[int] = mapped_column( + primary_key=True, use_existing_column=True ) is_( @@ -910,20 +944,30 @@ class DeclarativeInheritanceTest(DeclarativeTestBase): Manager.target_id.property.columns[0], Person.__table__.c.target_id ) - def test_columns_single_inheritance_cascading_resolution_pk(self): + @testing.variation("decl_type", ["legacy", "use_existing_column"]) + def test_columns_single_inheritance_cascading_resolution_pk( + self, decl_type + ): """An additional test for #4352 in terms of the requested use case.""" class TestBase(Base): __abstract__ = True - @declared_attr.cascading - def id(cls): - col_val = None - if TestBase not in cls.__bases__: - col_val = cls.__table__.c.get("id") - if col_val is None: - col_val = Column(Integer, primary_key=True) - return col_val + if decl_type.legacy: + + @declared_attr.cascading + def id(cls): # noqa: A001 + col_val = None + if TestBase not in cls.__bases__: + col_val = cls.__table__.c.get("id") + if col_val is None: + col_val = Column(Integer, primary_key=True) + return col_val + + elif decl_type.use_existing_column: + id: Mapped[int] = mapped_column( # noqa: A001 + primary_key=True, use_existing_column=True + ) class Person(TestBase): """single table base class""" -- 2.47.2