]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
coerce elements in mapper.primary_key, process in __mapper_args__
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Feb 2023 21:35:21 +0000 (16:35 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 5 Feb 2023 15:39:01 +0000 (10:39 -0500)
Repaired ORM Declarative mappings to allow for the
:paramref:`_orm.Mapper.primary_key` parameter to be specified within
``__mapper_args__`` when using :func:`_orm.mapped_column`. Despite this
usage being directly in the 2.0 documentation, the :class:`_orm.Mapper` was
not accepting the :func:`_orm.mapped_column` construct in this context. Ths
feature was already working for the :paramref:`_orm.Mapper.version_id_col`
and :paramref:`_orm.Mapper.polymorphic_on` parameters.

As part of this change, the ``__mapper_args__`` attribute may be specified
without using :func:`_orm.declared_attr` on a non-mapped mixin class,
including a ``"primary_key"`` entry that refers to :class:`_schema.Column`
or :func:`_orm.mapped_column` objects locally present on the mixin;
Declarative will also translate these columns into the correct ones for a
particular mapped class. This again was working already for the
:paramref:`_orm.Mapper.version_id_col` and
:paramref:`_orm.Mapper.polymorphic_on` parameters.  Additionally,
elements within ``"primary_key"`` may be indicated as string names of
existing mapped properties.

Fixes: #9240
Change-Id: Ie2000273289fa23e0af21ef9c6feb3962a8b848c

doc/build/changelog/unreleased_20/9240.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/mapper.py
test/orm/declarative/test_basic.py

diff --git a/doc/build/changelog/unreleased_20/9240.rst b/doc/build/changelog/unreleased_20/9240.rst
new file mode 100644 (file)
index 0000000..23e807f
--- /dev/null
@@ -0,0 +1,22 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9240
+
+    Repaired ORM Declarative mappings to allow for the
+    :paramref:`_orm.Mapper.primary_key` parameter to be specified within
+    ``__mapper_args__`` when using :func:`_orm.mapped_column`. Despite this
+    usage being directly in the 2.0 documentation, the :class:`_orm.Mapper` was
+    not accepting the :func:`_orm.mapped_column` construct in this context. Ths
+    feature was already working for the :paramref:`_orm.Mapper.version_id_col`
+    and :paramref:`_orm.Mapper.polymorphic_on` parameters.
+
+    As part of this change, the ``__mapper_args__`` attribute may be specified
+    without using :func:`_orm.declared_attr` on a non-mapped mixin class,
+    including a ``"primary_key"`` entry that refers to :class:`_schema.Column`
+    or :func:`_orm.mapped_column` objects locally present on the mixin;
+    Declarative will also translate these columns into the correct ones for a
+    particular mapped class. This again was working already for the
+    :paramref:`_orm.Mapper.version_id_col` and
+    :paramref:`_orm.Mapper.polymorphic_on` parameters.  Additionally,
+    elements within ``"primary_key"`` may be indicated as string names of
+    existing mapped properties.
index a858f12cb947c6c1adb564453e0c603537b45a61..37fa964b844e6f98f3c4ee8af29c1d92d014e2ec 100644 (file)
@@ -1721,6 +1721,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                 v = mapper_args[k]
                 mapper_args[k] = self.column_copies.get(v, v)
 
+        if "primary_key" in mapper_args:
+            mapper_args["primary_key"] = [
+                self.column_copies.get(v, v)
+                for v in util.to_list(mapper_args["primary_key"])
+            ]
+
         if "inherits" in mapper_args:
             inherits_arg = mapper_args["inherits"]
             if isinstance(inherits_arg, Mapper):
index a3b209e4a6507042d40de7394df0bc892c4f6375..660c616912affdea28588f3bfe418428f6857699 100644 (file)
@@ -83,6 +83,7 @@ from ..sql import util as sql_util
 from ..sql import visitors
 from ..sql.cache_key import MemoizedHasCacheKey
 from ..sql.elements import KeyedColumnElement
+from ..sql.schema import Column
 from ..sql.schema import Table
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
 from ..util import HasMemoized
@@ -112,7 +113,6 @@ if TYPE_CHECKING:
     from ..sql.base import ReadOnlyColumnCollection
     from ..sql.elements import ColumnClause
     from ..sql.elements import ColumnElement
-    from ..sql.schema import Column
     from ..sql.selectable import FromClause
     from ..util import OrderedSet
 
@@ -650,11 +650,15 @@ class Mapper(
                :ref:`orm_mapping_classes_toplevel`
 
         :param primary_key: A list of :class:`_schema.Column`
-           objects which define
+           objects, or alternatively string names of attribute names which
+           refer to :class:`_schema.Column`, which define
            the primary key to be used against this mapper's selectable unit.
            This is normally simply the primary key of the ``local_table``, but
            can be overridden here.
 
+           .. versionchanged:: 2.0.2 :paramref:`_orm.Mapper.primary_key`
+              arguments may be indicated as string attribute names as well.
+
            .. seealso::
 
                 :ref:`mapper_primary_key` - background and example use
@@ -1557,6 +1561,29 @@ class Mapper(
 
         self.__dict__.pop("_configure_failed", None)
 
+    def _str_arg_to_mapped_col(self, argname: str, key: str) -> Column[Any]:
+        try:
+            prop = self._props[key]
+        except KeyError as err:
+            raise sa_exc.ArgumentError(
+                f"Can't determine {argname} column '{key}' - "
+                "no attribute is mapped to this name."
+            ) from err
+        try:
+            expr = prop.expression
+        except AttributeError as ae:
+            raise sa_exc.ArgumentError(
+                f"Can't determine {argname} column '{key}'; "
+                "property does not refer to a single mapped Column"
+            ) from ae
+        if not isinstance(expr, Column):
+            raise sa_exc.ArgumentError(
+                f"Can't determine {argname} column '{key}'; "
+                "property does not refer to a single "
+                "mapped Column"
+            )
+        return expr
+
     def _configure_pks(self) -> None:
         self.tables = sql_util.find_tables(self.persist_selectable)
 
@@ -1585,10 +1612,28 @@ class Mapper(
                 all_cols
             )
 
+        if self._primary_key_argument:
+
+            coerced_pk_arg = [
+                self._str_arg_to_mapped_col("primary_key", c)
+                if isinstance(c, str)
+                else c
+                for c in (
+                    coercions.expect(  # type: ignore
+                        roles.DDLConstraintColumnRole,
+                        coerce_pk,
+                        argname="primary_key",
+                    )
+                    for coerce_pk in self._primary_key_argument
+                )
+            ]
+        else:
+            coerced_pk_arg = None
+
         # if explicit PK argument sent, add those columns to the
         # primary key mappings
-        if self._primary_key_argument:
-            for k in self._primary_key_argument:
+        if coerced_pk_arg:
+            for k in coerced_pk_arg:
                 if k.table not in self._pks_by_table:
                     self._pks_by_table[k.table] = util.OrderedSet()
                 self._pks_by_table[k.table].add(k)
@@ -1625,12 +1670,12 @@ class Mapper(
             # determine primary key from argument or persist_selectable pks
             primary_key: Collection[ColumnElement[Any]]
 
-            if self._primary_key_argument:
+            if coerced_pk_arg:
                 primary_key = [
                     cc if cc is not None else c
                     for cc, c in (
                         (self.persist_selectable.corresponding_column(c), c)
-                        for c in self._primary_key_argument
+                        for c in coerced_pk_arg
                     )
                 ]
             else:
index e2108f8886fda1ee06d7c446aef8da06531228be..45f0d4200bb9d445c6ab3fb1f2bc790ddcac733d 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy.orm import deferred
 from sqlalchemy.orm import descriptor_props
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import MappedColumn
 from sqlalchemy.orm import Mapper
@@ -188,6 +189,251 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
             ):
                 Base.__init__(fs, x=5)
 
+    @testing.variation("argument", ["version_id_col", "polymorphic_on"])
+    @testing.variation("column_type", ["anno", "non_anno", "plain_column"])
+    def test_mapped_column_version_poly_arg(
+        self, decl_base, column_type, argument
+    ):
+        """test #9240"""
+
+        if column_type.anno:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a: Mapped[int] = mapped_column(primary_key=True)
+                b: Mapped[int] = mapped_column()
+                c: Mapped[str] = mapped_column()
+
+                if argument.version_id_col:
+                    __mapper_args__ = {"version_id_col": b}
+                elif argument.polymorphic_on:
+                    __mapper_args__ = {"polymorphic_on": c}
+                else:
+                    argument.fail()
+
+        elif column_type.non_anno:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a = mapped_column(Integer, primary_key=True)
+                b = mapped_column(Integer)
+                c = mapped_column(String)
+
+                if argument.version_id_col:
+                    __mapper_args__ = {"version_id_col": b}
+                elif argument.polymorphic_on:
+                    __mapper_args__ = {"polymorphic_on": c}
+                else:
+                    argument.fail()
+
+        elif column_type.plain_column:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a = Column(Integer, primary_key=True)
+                b = Column(Integer)
+                c = Column(String)
+
+                if argument.version_id_col:
+                    __mapper_args__ = {"version_id_col": b}
+                elif argument.polymorphic_on:
+                    __mapper_args__ = {"polymorphic_on": c}
+                else:
+                    argument.fail()
+
+        else:
+            column_type.fail()
+
+        if argument.version_id_col:
+            assert A.__mapper__.version_id_col is A.__table__.c.b
+        elif argument.polymorphic_on:
+            assert A.__mapper__.polymorphic_on is A.__table__.c.c
+        else:
+            argument.fail()
+
+    @testing.variation(
+        "pk_type", ["single", "tuple", "list", "single_str", "list_str"]
+    )
+    @testing.variation("column_type", ["anno", "non_anno", "plain_column"])
+    def test_mapped_column_pk_arg(self, decl_base, column_type, pk_type):
+        """test #9240"""
+
+        if column_type.anno:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a: Mapped[int] = mapped_column()
+                b: Mapped[int] = mapped_column()
+
+                if pk_type.single:
+                    __mapper_args__ = {"primary_key": a}
+                elif pk_type.tuple:
+                    __mapper_args__ = {"primary_key": (a, b)}
+                elif pk_type.list:
+                    __mapper_args__ = {"primary_key": [a, b]}
+                elif pk_type.single_str:
+                    __mapper_args__ = {"primary_key": "a"}
+                elif pk_type.list_str:
+                    __mapper_args__ = {"primary_key": ["a", "b"]}
+                else:
+                    pk_type.fail()
+
+        elif column_type.non_anno:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a = mapped_column(Integer)
+                b = mapped_column(Integer)
+
+                if pk_type.single:
+                    __mapper_args__ = {"primary_key": a}
+                elif pk_type.tuple:
+                    __mapper_args__ = {"primary_key": (a, b)}
+                elif pk_type.list:
+                    __mapper_args__ = {"primary_key": [a, b]}
+                elif pk_type.single_str:
+                    __mapper_args__ = {"primary_key": "a"}
+                elif pk_type.list_str:
+                    __mapper_args__ = {"primary_key": ["a", "b"]}
+                else:
+                    pk_type.fail()
+
+        elif column_type.plain_column:
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a = Column(Integer)
+                b = Column(Integer)
+
+                if pk_type.single:
+                    __mapper_args__ = {"primary_key": a}
+                elif pk_type.tuple:
+                    __mapper_args__ = {"primary_key": (a, b)}
+                elif pk_type.list:
+                    __mapper_args__ = {"primary_key": [a, b]}
+                elif pk_type.single_str:
+                    __mapper_args__ = {"primary_key": "a"}
+                elif pk_type.list_str:
+                    __mapper_args__ = {"primary_key": ["a", "b"]}
+                else:
+                    pk_type.fail()
+
+        else:
+            column_type.fail()
+
+        if pk_type.single or pk_type.single_str:
+            assert A.__mapper__.primary_key[0] is A.__table__.c.a
+        else:
+            assert A.__mapper__.primary_key[0] is A.__table__.c.a
+            assert A.__mapper__.primary_key[1] is A.__table__.c.b
+
+    def test_mapper_pk_arg_degradation_no_col(self, decl_base):
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            "Can't determine primary_key column 'q' - no attribute is "
+            "mapped to this name.",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a: Mapped[int] = mapped_column()
+                b: Mapped[int] = mapped_column()
+
+                __mapper_args__ = {"primary_key": "q"}
+
+    @testing.variation("proptype", ["relationship", "colprop"])
+    def test_mapper_pk_arg_degradation_is_not_a_col(self, decl_base, proptype):
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            "Can't determine primary_key column 'b'; property does "
+            "not refer to a single mapped Column",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                a: Mapped[int] = mapped_column(Integer)
+
+                if proptype.colprop:
+                    b: Mapped[int] = column_property(a + 5)
+                elif proptype.relationship:
+                    b = relationship("B")
+                else:
+                    proptype.fail()
+
+                __mapper_args__ = {"primary_key": "b"}
+
+    @testing.variation(
+        "argument", ["version_id_col", "polymorphic_on", "primary_key"]
+    )
+    @testing.variation("argtype", ["callable", "fixed"])
+    @testing.variation("column_type", ["mapped_column", "plain_column"])
+    def test_mapped_column_pk_arg_via_mixin(
+        self, decl_base, argtype, column_type, argument
+    ):
+        """test #9240"""
+
+        class Mixin:
+            if column_type.mapped_column:
+                a: Mapped[int] = mapped_column()
+                b: Mapped[int] = mapped_column()
+                c: Mapped[str] = mapped_column()
+            elif column_type.plain_column:
+                a = Column(Integer)
+                b = Column(Integer)
+                c = Column(String)
+            else:
+                column_type.fail()
+
+            if argtype.callable:
+
+                @declared_attr.directive
+                @classmethod
+                def __mapper_args__(cls):
+                    if argument.primary_key:
+                        return {"primary_key": [cls.a, cls.b]}
+                    elif argument.version_id_col:
+                        return {"version_id_col": cls.b, "primary_key": cls.a}
+                    elif argument.polymorphic_on:
+                        return {"polymorphic_on": cls.c, "primary_key": cls.a}
+                    else:
+                        argument.fail()
+
+            elif argtype.fixed:
+                if argument.primary_key:
+                    __mapper_args__ = {"primary_key": [a, b]}
+                elif argument.version_id_col:
+                    __mapper_args__ = {"primary_key": a, "version_id_col": b}
+                elif argument.polymorphic_on:
+                    __mapper_args__ = {"primary_key": a, "polymorphic_on": c}
+                else:
+                    argument.fail()
+
+            else:
+                argtype.fail()
+
+        class A(Mixin, decl_base):
+            __tablename__ = "a"
+
+        if argument.primary_key:
+            assert A.__mapper__.primary_key[0] is A.__table__.c.a
+            assert A.__mapper__.primary_key[1] is A.__table__.c.b
+        elif argument.version_id_col:
+            assert A.__mapper__.version_id_col is A.__table__.c.b
+        elif argument.polymorphic_on:
+            assert A.__mapper__.polymorphic_on is A.__table__.c.c
+        else:
+            argtype.fail()
+
     def test_dispose_attrs(self):
         reg = registry()