]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support renamed symbols in annotation scans
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Nov 2022 15:04:13 +0000 (11:04 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 4 Nov 2022 17:38:06 +0000 (13:38 -0400)
Added support in ORM declarative annotations for class names specified for
:func:`_orm.relationship`, as well as the name of the :class:`_orm.Mapped`
symbol itself, to be different names than their direct class name, to
support scenarios such as where :class:`_orm.Mapped` is imported as
``from sqlalchemy.orm import Mapped as M``, or where related class names
are imported with an alternate name in a similar fashion. Additionally, a
target class name given as the lead argument for :func:`_orm.relationship`
will always supersede the name given in the left hand annotation, so that
otherwise un-importable names that also don't match the class name can
still be used in annotations.

Fixes: #8759
Change-Id: I74a00de7e1a45bf62dad50fd385bb75cf343f9f3

doc/build/changelog/unreleased_20/8759.rst [new file with mode: 0644]
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8759.rst b/doc/build/changelog/unreleased_20/8759.rst
new file mode 100644 (file)
index 0000000..3f9f482
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, orm, declarative
+    :tickets: 8759
+
+    Added support in ORM declarative annotations for class names specified for
+    :func:`_orm.relationship`, as well as the name of the :class:`_orm.Mapped`
+    symbol itself, to be different names than their direct class name, to
+    support scenarios such as where :class:`_orm.Mapped` is imported as
+    ``from sqlalchemy.orm import Mapped as M``, or where related class names
+    are imported with an alternate name in a similar fashion. Additionally, a
+    target class name given as the lead argument for :func:`_orm.relationship`
+    will always supersede the name given in the left hand annotation, so that
+    otherwise un-importable names that also don't match the class name can
+    still be used in annotations.
index 6d388a630b59b2226d9c47f6740f9950e9926283..e0922a5380f7a3ae90990726bdb5a9b8ef6be740 100644 (file)
@@ -88,6 +88,7 @@ from ..sql.util import selectables_overlap
 from ..sql.util import visit_binary_product
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import Literal
+from ..util.typing import resolve_name_to_real_class_name
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
@@ -1729,6 +1730,7 @@ class RelationshipProperty(
                 return
 
         argument = extracted_mapped_annotation
+        assert originating_module is not None
 
         is_write_only = mapped_container is not None and issubclass(
             mapped_container, WriteOnlyMapped
@@ -1765,7 +1767,10 @@ class RelationshipProperty(
                     type_arg = argument.__args__[0]  # type: ignore
                 if hasattr(type_arg, "__forward_arg__"):
                     str_argument = type_arg.__forward_arg__
-                    argument = str_argument
+
+                    argument = resolve_name_to_real_class_name(
+                        str_argument, originating_module
+                    )
                 else:
                     argument = type_arg
             else:
@@ -1775,6 +1780,10 @@ class RelationshipProperty(
         elif hasattr(argument, "__forward_arg__"):
             argument = argument.__forward_arg__  # type: ignore
 
+            argument = resolve_name_to_real_class_name(
+                argument, originating_module
+            )
+
             # we don't allow the collection class to be a
             # __forward_arg__ right now, so if we see a forward arg here,
             # we know there was no collection class either
@@ -1785,7 +1794,14 @@ class RelationshipProperty(
             ):
                 self.uselist = False
 
-        self.argument = cast("_RelationshipArgumentType[_T]", argument)
+        # ticket #8759
+        # if a lead argument was given to relationship(), like
+        # `relationship("B")`, use that, don't replace it with class we
+        # found in the annotation.  The declarative_scan() method call here is
+        # still useful, as we continue to derive collection type and do
+        # checking of the annotation in any case.
+        if self.argument is None:
+            self.argument = cast("_RelationshipArgumentType[_T]", argument)
 
     @util.preload_module("sqlalchemy.orm.mapper")
     def _setup_entity(self, __argument: Any = None) -> None:
index e64c936363fa510b23cfc8d4631e60af3b58be71..6cd98f5ea078c5777af543652a7d9e2925dd219f 100644 (file)
@@ -78,6 +78,7 @@ from ..sql.elements import KeyedColumnElement
 from ..sql.selectable import FromClause
 from ..util.langhelpers import MemoizedSlots
 from ..util.typing import de_stringify_annotation
+from ..util.typing import eval_name_only
 from ..util.typing import is_origin_of_cls
 from ..util.typing import Literal
 from ..util.typing import typing_get_origin
@@ -2034,35 +2035,75 @@ def _is_mapped_annotation(
         return is_origin_of_cls(annotated, _MappedAnnotationBase)
 
 
-def _cleanup_mapped_str_annotation(annotation: str) -> str:
+class _CleanupError(Exception):
+    pass
+
+
+def _cleanup_mapped_str_annotation(
+    annotation: str, originating_module: str
+) -> str:
     # fix up an annotation that comes in as the form:
     # 'Mapped[List[Address]]'  so that it instead looks like:
     # 'Mapped[List["Address"]]' , which will allow us to get
     # "Address" as a string
 
+    # additionally, resolve symbols for these names since this is where
+    # we'd have to do it
+
     inner: Optional[Match[str]]
 
     mm = re.match(r"^(.+?)\[(.+)\]$", annotation)
-    if mm and mm.group(1) in ("Mapped", "WriteOnlyMapped", "DynamicMapped"):
-        stack = []
-        inner = mm
-        while True:
-            stack.append(inner.group(1))
-            g2 = inner.group(2)
-            inner = re.match(r"^(.+?)\[(.+)\]$", g2)
-            if inner is None:
-                stack.append(g2)
-                break
-
-        # stack: ['Mapped', 'List', 'Address']
-        if not re.match(r"""^["'].*["']$""", stack[-1]):
-            stripchars = "\"' "
-            stack[-1] = ", ".join(
-                f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
-            )
-            # stack: ['Mapped', 'List', '"Address"']
 
-            annotation = "[".join(stack) + ("]" * (len(stack) - 1))
+    if not mm:
+        return annotation
+
+    # ticket #8759.  Resolve the Mapped name to a real symbol.
+    # originally this just checked the name.
+    try:
+        obj = eval_name_only(mm.group(1), originating_module)
+    except NameError as ne:
+        raise _CleanupError(
+            f'For annotation "{annotation}", could not resolve '
+            f'container type "{mm.group(1)}".  '
+            "Please ensure this type is imported at the module level "
+            "outside of TYPE_CHECKING blocks"
+        ) from ne
+
+    try:
+        if issubclass(obj, _MappedAnnotationBase):
+            real_symbol = obj.__name__
+        else:
+            return annotation
+    except TypeError:
+        # avoid isinstance(obj, type) check, just catch TypeError
+        return annotation
+
+    # note: if one of the codepaths above didn't define real_symbol and
+    # then didn't return, real_symbol raises UnboundLocalError
+    # which is actually a NameError, and the calling routines don't
+    # notice this since they are catching NameError anyway.   Just in case
+    # this is being modified in the future, something to be aware of.
+
+    stack = []
+    inner = mm
+    while True:
+        stack.append(real_symbol if mm is inner else inner.group(1))
+        g2 = inner.group(2)
+        inner = re.match(r"^(.+?)\[(.+)\]$", g2)
+        if inner is None:
+            stack.append(g2)
+            break
+
+    # stack: ['Mapped', 'List', 'Address']
+    if not re.match(r"""^["'].*["']$""", stack[-1]):
+        stripchars = "\"' "
+        stack[-1] = ", ".join(
+            f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",")
+        )
+        # stack: ['Mapped', 'List', '"Address"']
+
+        annotation = "[".join(stack) + ("]" * (len(stack) - 1))
+
     return annotation
 
 
@@ -2101,12 +2142,18 @@ def _extract_mapped_subtype(
             originating_module,
             _cleanup_mapped_str_annotation,
         )
+    except _CleanupError as ce:
+        raise sa_exc.ArgumentError(
+            f"Could not interpret annotation {raw_annotation}.  "
+            "Check that it uses names that are correctly imported at the "
+            "module level. See chained stack trace for more hints."
+        ) from ce
     except NameError as ne:
         if raiseerr and "Mapped[" in raw_annotation:  # type: ignore
             raise sa_exc.ArgumentError(
                 f"Could not interpret annotation {raw_annotation}.  "
-                "Check that it's not using names that might not be imported "
-                "at the module level.  See chained stack trace for more hints."
+                "Check that it uses names that are correctly imported at the "
+                "module level. See chained stack trace for more hints."
             ) from ne
 
         annotated = raw_annotation  # type: ignore
index b5d918a74e4db5677c60f871d158596b07f44495..e4674a44cb03ca4aee1aa722ae390f99b952b229 100644 (file)
@@ -85,7 +85,7 @@ def de_stringify_annotation(
     cls: Type[Any],
     annotation: _AnnotationScanType,
     originating_module: str,
-    str_cleanup_fn: Optional[Callable[[str], str]] = None,
+    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
 ) -> Type[Any]:
     """Resolve annotations that may be string based into real objects.
 
@@ -109,26 +109,65 @@ def de_stringify_annotation(
 
     if isinstance(annotation, str):
         if str_cleanup_fn:
-            annotation = str_cleanup_fn(annotation)
-        base_globals: "Dict[str, Any]" = getattr(
-            sys.modules.get(originating_module, None), "__dict__", {}
-        )
-
-        try:
-            annotation = eval(annotation, base_globals, None)
-        except NameError as err:
-            # breakpoint()
-            raise NameError(
-                f"Could not de-stringify annotation {annotation}"
-            ) from err
+            annotation = str_cleanup_fn(annotation, originating_module)
+
+        annotation = eval_expression(annotation, originating_module)
     return annotation  # type: ignore
 
 
+def eval_expression(expression: str, module_name: str) -> Any:
+    try:
+        base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
+    except KeyError as ke:
+        raise NameError(
+            f"Module {module_name} isn't present in sys.modules; can't "
+            f"evaluate expression {expression}"
+        ) from ke
+    try:
+        annotation = eval(expression, base_globals, None)
+    except Exception as err:
+        raise NameError(
+            f"Could not de-stringify annotation {expression}"
+        ) from err
+    else:
+        return annotation
+
+
+def eval_name_only(name: str, module_name: str) -> Any:
+
+    try:
+        base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
+    except KeyError as ke:
+        raise NameError(
+            f"Module {module_name} isn't present in sys.modules; can't "
+            f"resolve name {name}"
+        ) from ke
+
+    # name only, just look in globals.  eval() works perfectly fine here,
+    # however we are seeking to have this be faster, as this occurs for
+    # every Mapper[] keyword, etc. depending on configuration
+    try:
+        return base_globals[name]
+    except KeyError as ke:
+        raise NameError(
+            f"Could not locate name {name} in module {module_name}"
+        ) from ke
+
+
+def resolve_name_to_real_class_name(name: str, module_name: str) -> str:
+    try:
+        obj = eval_name_only(name, module_name)
+    except NameError:
+        return name
+    else:
+        return getattr(obj, "__name__", name)
+
+
 def de_stringify_union_elements(
     cls: Type[Any],
     annotation: _AnnotationScanType,
     originating_module: str,
-    str_cleanup_fn: Optional[Callable[[str], str]] = None,
+    str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
 ) -> Type[Any]:
     return make_union_type(
         *[
index 03e41e7cfa913f781a410f102b0f0050d507611d..d66b08f4e0d64715dc3a59b18ba301c4e71029dd 100644 (file)
@@ -4,6 +4,7 @@ from decimal import Decimal
 from typing import List
 from typing import Optional
 from typing import Set
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 import uuid
@@ -31,6 +32,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.util import compat
+from .test_typed_mapping import expect_annotation_syntax_error
 from .test_typed_mapping import MappedColumnTest as _MappedColumnTest
 from .test_typed_mapping import RelationshipLHSTest as _RelationshipLHSTest
 from .test_typed_mapping import (
@@ -45,8 +47,97 @@ having ``from __future__ import annotations`` in effect.
 
 _R = TypeVar("_R")
 
+M = Mapped
+
+
+class M3:
+    pass
+
 
 class MappedColumnTest(_MappedColumnTest):
+    def test_indirect_mapped_name_module_level(self, decl_base):
+        """test #8759
+
+
+        Note that M by definition has to be at the module level to be
+        valid, and not locally declared here, this is in accordance with
+        mypy::
+
+
+            def make_class() -> None:
+                ll = list
+
+                x: ll[int] = [1, 2, 3]
+
+        Will return::
+
+            $ mypy test3.py
+            test3.py:4: error: Variable "ll" is not valid as a type  [valid-type]
+            test3.py:4: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
+            Found 1 error in 1 file (checked 1 source file)
+
+        Whereas the correct form is::
+
+            ll = list
+
+            def make_class() -> None:
+
+                x: ll[int] = [1, 2, 3]
+
+
+        """  # noqa: E501
+
+        class Foo(decl_base):
+            __tablename__ = "foo"
+
+            id: M[int] = mapped_column(primary_key=True)
+
+            data: M[int] = mapped_column()
+
+            data2: M[int]
+
+        self.assert_compile(
+            select(Foo), "SELECT foo.id, foo.data, foo.data2 FROM foo"
+        )
+
+    def test_indirect_mapped_name_local_level(self, decl_base):
+        """test #8759.
+
+        this should raise an error.
+
+        """
+
+        M2 = Mapped
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Could not interpret annotation M2\[int\].  Check that it "
+            "uses names that are correctly imported at the module level.",
+        ):
+
+            class Foo(decl_base):
+                __tablename__ = "foo"
+
+                id: M2[int] = mapped_column(primary_key=True)
+
+                data2: M2[int]
+
+    def test_indirect_mapped_name_itswrong(self, decl_base):
+        """test #8759.
+
+        this should raise an error.
+
+        """
+
+        with expect_annotation_syntax_error("Foo.id"):
+
+            class Foo(decl_base):
+                __tablename__ = "foo"
+
+                id: M3[int] = mapped_column(primary_key=True)
+
+                data2: M3[int]
+
     def test_unions(self):
         our_type = Numeric(10, 2)
 
@@ -394,6 +485,124 @@ class RelationshipLHSTest(_RelationshipLHSTest):
         a1.bs.append(b1)
         is_(a1, b1.a)
 
+    @testing.combinations(
+        "include_relationship",
+        "no_relationship",
+        argnames="include_relationship",
+    )
+    @testing.combinations(
+        "direct_name", "indirect_name", argnames="indirect_name"
+    )
+    def test_indirect_name_collection(
+        self, decl_base, include_relationship, indirect_name
+    ):
+        """test #8759"""
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+        global B_
+        B_ = B
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            if indirect_name == "indirect_name":
+                if include_relationship == "include_relationship":
+                    bs: Mapped[List[B_]] = relationship("B")
+                else:
+                    bs: Mapped[List[B_]] = relationship()
+            else:
+                if include_relationship == "include_relationship":
+                    bs: Mapped[List[B]] = relationship("B")
+                else:
+                    bs: Mapped[List[B]] = relationship()
+
+        self.assert_compile(
+            select(A).join(A.bs),
+            "SELECT a.id, a.data FROM a JOIN b ON a.id = b.a_id",
+        )
+
+    @testing.combinations(
+        "include_relationship",
+        "no_relationship",
+        argnames="include_relationship",
+    )
+    @testing.combinations(
+        "direct_name", "indirect_name", argnames="indirect_name"
+    )
+    def test_indirect_name_scalar(
+        self, decl_base, include_relationship, indirect_name
+    ):
+        """test #8759"""
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+        global A_
+        A_ = A
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+            if indirect_name == "indirect_name":
+                if include_relationship == "include_relationship":
+                    a: Mapped[A_] = relationship("A")
+                else:
+                    a: Mapped[A_] = relationship()
+            else:
+                if include_relationship == "include_relationship":
+                    a: Mapped[A] = relationship("A")
+                else:
+                    a: Mapped[A] = relationship()
+
+        self.assert_compile(
+            select(B).join(B.a),
+            "SELECT b.id, b.a_id FROM b JOIN a ON a.id = b.a_id",
+        )
+
+    def test_indirect_name_relationship_arg_override(self, decl_base):
+        """test #8759
+
+        in this test we assume a case where the type for the Mapped annnotation
+        a. has to be a different name than the actual class name and
+        b. cannot be imported outside of TYPE CHECKING.  user will then put
+        the real name inside of relationship().  we have to succeed even though
+        we can't resolve the annotation.
+
+        """
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+        if TYPE_CHECKING:
+            BNonExistent = B
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            bs: Mapped[List[BNonExistent]] = relationship("B")
+
+        self.assert_compile(
+            select(A).join(A.bs),
+            "SELECT a.id, a.data FROM a JOIN b ON a.id = b.a_id",
+        )
+
 
 class WriteOnlyRelationshipTest(_WriteOnlyRelationshipTest):
     def test_dynamic(self, decl_base):
index 72fd4d84e9ba18a0155ab2d41f5759fd60b45c3a..e3f5e59f4d703179aad5844085958f461e2ab27e 100644 (file)
@@ -1490,6 +1490,46 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         is_(a1.bs["foo"], b1)
 
+    @testing.combinations(
+        "include_relationship",
+        "no_relationship",
+        argnames="include_relationship",
+    )
+    @testing.combinations(
+        "direct_name", "indirect_name", argnames="indirect_name"
+    )
+    def test_indirect_name(
+        self, decl_base, include_relationship, indirect_name
+    ):
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Integer, primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+
+        B_ = B
+
+        class A(decl_base):
+            __tablename__ = "a"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column()
+
+            if indirect_name == "indirect_name":
+                if include_relationship == "include_relationship":
+                    bs: Mapped[List[B_]] = relationship("B")
+                else:
+                    bs: Mapped[List[B_]] = relationship()
+            else:
+                if include_relationship == "include_relationship":
+                    bs: Mapped[List[B]] = relationship("B")
+                else:
+                    bs: Mapped[List[B]] = relationship()
+
+        self.assert_compile(
+            select(A).join(A.bs),
+            "SELECT a.id, a.data FROM a JOIN b ON a.id = b.a_id",
+        )
+
 
 class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"