]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
apply a fixed locals w/ Mapped to all de-stringify
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Feb 2023 15:34:01 +0000 (10:34 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Feb 2023 03:20:11 +0000 (22:20 -0500)
Continued the fix for :ticket:`8853`, allowing the :class:`_orm.Mapped`
name to be fully qualified regardless of whether or not
``from __annotations__ import future`` were present. This issue first fixed
in 2.0.0b3 confirmed that this case worked via the test suite, however the
test suite apparently was not testing the behavior for the name ``Mapped``
not being locally present at all; string resolution has been updated to
ensure the ``Mapped`` symbol is locatable as applies to how the ORM uses
these functions.

Fixes: #8853
Fixes: #9335
Change-Id: Id82d09aee906165a4d77c7da6a0b4177dd675c10

doc/build/changelog/unreleased_20/8853.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_abs_import_only.py [new file with mode: 0644]
test/orm/declarative/test_tm_future_annotations.py

diff --git a/doc/build/changelog/unreleased_20/8853.rst b/doc/build/changelog/unreleased_20/8853.rst
new file mode 100644 (file)
index 0000000..a10ec43
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8853, 9335
+
+    Continued the fix for :ticket:`8853`, allowing the :class:`_orm.Mapped`
+    name to be fully qualified regardless of whether or not
+    ``from __annotations__ import future`` were present. This issue first fixed
+    in 2.0.0b3 confirmed that this case worked via the test suite, however the
+    test suite apparently was not testing the behavior for the name ``Mapped``
+    not being locally present at all; string resolution has been updated to
+    ensure the ``Mapped`` symbol is locatable as applies to how the ORM uses
+    these functions.
index 29d74859618e16b4338f1aaabacacee23a7397ba..d01aad439533e6399733fdc1c44d0a9670637863 100644 (file)
@@ -55,6 +55,7 @@ from .properties import MappedColumn
 from .util import _extract_mapped_subtype
 from .util import _is_mapped_annotation
 from .util import class_mapper
+from .util import de_stringify_annotation
 from .. import event
 from .. import exc
 from .. import util
@@ -64,7 +65,6 @@ from ..sql.schema import Column
 from ..sql.schema import Table
 from ..util import topological
 from ..util.typing import _AnnotationScanType
-from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_literal
 from ..util.typing import Protocol
index b65171c9d4775eaf2693b1e31f3b834d17edd516..fd28830d9fb4b9ce68ce75b770bdf11f43dbe16a 100644 (file)
@@ -44,6 +44,7 @@ from .interfaces import _MapsColumns
 from .interfaces import MapperProperty
 from .interfaces import PropComparator
 from .util import _none_set
+from .util import de_stringify_annotation
 from .. import event
 from .. import exc as sa_exc
 from .. import schema
@@ -52,7 +53,6 @@ from .. import util
 from ..sql import expression
 from ..sql import operators
 from ..sql.elements import BindParameter
-from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_pep593
 from ..util.typing import typing_get_args
index a5f34f3de38fa56e3d60f7dbe73431c386f08f9f..4c07bad235f6aabc9c75456835635a7a99934165 100644 (file)
@@ -41,6 +41,8 @@ from .interfaces import MapperProperty
 from .interfaces import PropComparator
 from .interfaces import StrategizedProperty
 from .relationships import RelationshipProperty
+from .util import de_stringify_annotation
+from .util import de_stringify_union_elements
 from .. import exc as sa_exc
 from .. import ForeignKey
 from .. import log
@@ -52,8 +54,6 @@ from ..sql.schema import Column
 from ..sql.schema import SchemaConst
 from ..sql.type_api import TypeEngine
 from ..util.typing import de_optionalize_union_types
-from ..util.typing import de_stringify_annotation
-from ..util.typing import de_stringify_union_elements
 from ..util.typing import is_fwd_ref
 from ..util.typing import is_optional_union
 from ..util.typing import is_pep593
index ad9ce2013da40f76746f63e59819f2da38b7b40c..7966f6cd9c6c26fee6bbad9432d4769c63b43729 100644 (file)
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import enum
+import functools
 import re
 import types
 import typing
@@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str  # noqa: F401
 from .base import class_mapper as class_mapper
 from .base import InspectionAttr as InspectionAttr
 from .base import instance_str as instance_str  # noqa: F401
+from .base import Mapped
 from .base import object_mapper as object_mapper
 from .base import object_state as object_state  # noqa: F401
 from .base import opt_manager_of_class
@@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement
 from ..sql.elements import KeyedColumnElement
 from ..sql.selectable import FromClause
 from ..util.langhelpers import MemoizedSlots
-from ..util.typing import de_stringify_annotation
-from ..util.typing import eval_name_only
+from ..util.typing import de_stringify_annotation as _de_stringify_annotation
+from ..util.typing import (
+    de_stringify_union_elements as _de_stringify_union_elements,
+)
+from ..util.typing import eval_name_only as _eval_name_only
 from ..util.typing import is_origin_of_cls
 from ..util.typing import Literal
+from ..util.typing import Protocol
 from ..util.typing import typing_get_origin
 
 if typing.TYPE_CHECKING:
@@ -113,6 +119,7 @@ if typing.TYPE_CHECKING:
     from ..sql.selectable import Subquery
     from ..sql.visitors import anon_map
     from ..util.typing import _AnnotationScanType
+    from ..util.typing import ArgsTypeProcotol
 
 _T = TypeVar("_T", bound=Any)
 
@@ -130,6 +137,58 @@ all_cascades = frozenset(
 )
 
 
+_de_stringify_partial = functools.partial(
+    functools.partial, locals_=util.immutabledict({"Mapped": Mapped})
+)
+
+# partial is practically useless as we have to write out the whole
+# function and maintain the signature anyway
+
+
+class _DeStringifyAnnotation(Protocol):
+    def __call__(
+        self,
+        cls: Type[Any],
+        annotation: _AnnotationScanType,
+        originating_module: str,
+        *,
+        str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+        include_generic: bool = False,
+    ) -> Type[Any]:
+        ...
+
+
+de_stringify_annotation = cast(
+    _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation)
+)
+
+
+class _DeStringifyUnionElements(Protocol):
+    def __call__(
+        self,
+        cls: Type[Any],
+        annotation: ArgsTypeProcotol,
+        originating_module: str,
+        *,
+        str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
+    ) -> Type[Any]:
+        ...
+
+
+de_stringify_union_elements = cast(
+    _DeStringifyUnionElements,
+    _de_stringify_partial(_de_stringify_union_elements),
+)
+
+
+class _EvalNameOnly(Protocol):
+    def __call__(self, name: str, module_name: str) -> Any:
+        ...
+
+
+eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only))
+
+
 class CascadeOptions(FrozenSet[str]):
     """Keeps track of the options sent to
     :paramref:`.relationship.cascade`"""
@@ -2271,7 +2330,7 @@ def _extract_mapped_subtype(
             cls,
             raw_annotation,
             originating_module,
-            _cleanup_mapped_str_annotation,
+            str_cleanup_fn=_cleanup_mapped_str_annotation,
         )
     except _CleanupError as ce:
         raise sa_exc.ArgumentError(
index 9e6df0d359dcc20fdcb8f91a45afb3842a5dd914..24d8dd2dc11e1d8c921fb49718a67b096c252566 100644 (file)
@@ -18,6 +18,7 @@ from typing import Dict
 from typing import ForwardRef
 from typing import Generic
 from typing import Iterable
+from typing import Mapping
 from typing import NewType
 from typing import NoReturn
 from typing import Optional
@@ -123,6 +124,8 @@ def de_stringify_annotation(
     cls: Type[Any],
     annotation: _AnnotationScanType,
     originating_module: str,
+    locals_: Mapping[str, Any],
+    *,
     str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
     include_generic: bool = False,
 ) -> Type[Any]:
@@ -150,7 +153,9 @@ def de_stringify_annotation(
         if str_cleanup_fn:
             annotation = str_cleanup_fn(annotation, originating_module)
 
-        annotation = eval_expression(annotation, originating_module)
+        annotation = eval_expression(
+            annotation, originating_module, locals_=locals_
+        )
 
     if (
         include_generic
@@ -162,6 +167,7 @@ def de_stringify_annotation(
                 cls,
                 elem,
                 originating_module,
+                locals_,
                 str_cleanup_fn=str_cleanup_fn,
                 include_generic=include_generic,
             )
@@ -183,7 +189,12 @@ def _copy_generic_annotation_with(
         return annotation.__origin__[elements]  # type: ignore
 
 
-def eval_expression(expression: str, module_name: str) -> Any:
+def eval_expression(
+    expression: str,
+    module_name: str,
+    *,
+    locals_: Optional[Mapping[str, Any]] = None,
+) -> Any:
     try:
         base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
     except KeyError as ke:
@@ -191,8 +202,9 @@ def eval_expression(expression: str, module_name: str) -> Any:
             f"Module {module_name} isn't present in sys.modules; can't "
             f"evaluate expression {expression}"
         ) from ke
+
     try:
-        annotation = eval(expression, base_globals, None)
+        annotation = eval(expression, base_globals, locals_)
     except Exception as err:
         raise NameError(
             f"Could not de-stringify annotation {expression!r}"
@@ -201,9 +213,14 @@ def eval_expression(expression: str, module_name: str) -> Any:
         return annotation
 
 
-def eval_name_only(name: str, module_name: str) -> Any:
+def eval_name_only(
+    name: str,
+    module_name: str,
+    *,
+    locals_: Optional[Mapping[str, Any]] = None,
+) -> Any:
     if "." in name:
-        return eval_expression(name, module_name)
+        return eval_expression(name, module_name, locals_=locals_)
 
     try:
         base_globals: Dict[str, Any] = sys.modules[module_name].__dict__
@@ -237,12 +254,18 @@ def de_stringify_union_elements(
     cls: Type[Any],
     annotation: ArgsTypeProcotol,
     originating_module: str,
+    locals_: Mapping[str, Any],
+    *,
     str_cleanup_fn: Optional[Callable[[str, str], str]] = None,
 ) -> Type[Any]:
     return make_union_type(
         *[
             de_stringify_annotation(
-                cls, anno, originating_module, str_cleanup_fn
+                cls,
+                anno,
+                originating_module,
+                {},
+                str_cleanup_fn=str_cleanup_fn,
             )
             for anno in annotation.__args__
         ]
diff --git a/test/orm/declarative/test_abs_import_only.py b/test/orm/declarative/test_abs_import_only.py
new file mode 100644 (file)
index 0000000..e700b4c
--- /dev/null
@@ -0,0 +1,47 @@
+"""
+this file tests that absolute imports can be used in declarative
+mappings while guaranteeing that the Mapped name is not locally present
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy
+from sqlalchemy import orm
+import sqlalchemy.orm
+import sqlalchemy.testing
+import sqlalchemy.testing.fixtures
+
+try:
+    x = Mapped  # type: ignore
+except NameError:
+    pass
+else:
+    raise Exception("Mapped name **must not be imported in this file**")
+
+
+class MappedColumnTest(
+    sqlalchemy.testing.fixtures.TestBase, sqlalchemy.testing.AssertsCompiledSQL
+):
+    __dialect__ = "default"
+
+    def test_fully_qualified_mapped_name(self, decl_base):
+        """test #8853 *again*, as reported in #9335 this failed to be fixed"""
+
+        class Foo(decl_base):
+            __tablename__ = "foo"
+
+            id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column(
+                primary_key=True
+            )
+
+            data: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column()
+
+            data2: sqlalchemy.orm.Mapped[int]
+
+            data3: orm.Mapped[int]
+
+        self.assert_compile(
+            sqlalchemy.select(Foo),
+            "SELECT foo.id, foo.data, foo.data2, foo.data3 FROM foo",
+        )
index b66d67a77ffc0aaaf3efa5ee09e774175d3035b2..1677cdbb9bde2cd1bd8617df111a79c3c5a8a101 100644 (file)
@@ -42,7 +42,12 @@ class M3:
 
 class MappedColumnTest(_MappedColumnTest):
     def test_fully_qualified_mapped_name(self, decl_base):
-        """test #8853, regression caused by #8759 ;)"""  # noqa: E501
+        """test #8853, regression caused by #8759 ;)
+
+
+        See same test in test_abs_import_only
+
+        """
 
         class Foo(decl_base):
             __tablename__ = "foo"