From 681055f9fb5230d344a67f47b0c60fc1a5804b3e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 21 Feb 2023 10:34:01 -0500 Subject: [PATCH] apply a fixed locals w/ Mapped to all de-stringify Continued the fix for :ticket:`8853`, allowing the :class:`_orm.Mapped` name to be fully qualified regardless of whether or not ``from __annotations__ import future`` were present. This issue first fixed in 2.0.0b3 confirmed that this case worked via the test suite, however the test suite apparently was not testing the behavior for the name ``Mapped`` not being locally present at all; string resolution has been updated to ensure the ``Mapped`` symbol is locatable as applies to how the ORM uses these functions. Fixes: #8853 Fixes: #9335 Change-Id: Id82d09aee906165a4d77c7da6a0b4177dd675c10 --- doc/build/changelog/unreleased_20/8853.rst | 12 ++++ lib/sqlalchemy/orm/decl_base.py | 2 +- lib/sqlalchemy/orm/descriptor_props.py | 2 +- lib/sqlalchemy/orm/properties.py | 4 +- lib/sqlalchemy/orm/util.py | 65 ++++++++++++++++++- lib/sqlalchemy/util/typing.py | 35 ++++++++-- test/orm/declarative/test_abs_import_only.py | 47 ++++++++++++++ .../declarative/test_tm_future_annotations.py | 7 +- 8 files changed, 160 insertions(+), 14 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8853.rst create mode 100644 test/orm/declarative/test_abs_import_only.py diff --git a/doc/build/changelog/unreleased_20/8853.rst b/doc/build/changelog/unreleased_20/8853.rst new file mode 100644 index 0000000000..a10ec43a66 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8853.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 8853, 9335 + + Continued the fix for :ticket:`8853`, allowing the :class:`_orm.Mapped` + name to be fully qualified regardless of whether or not + ``from __annotations__ import future`` were present. This issue first fixed + in 2.0.0b3 confirmed that this case worked via the test suite, however the + test suite apparently was not testing the behavior for the name ``Mapped`` + not being locally present at all; string resolution has been updated to + ensure the ``Mapped`` symbol is locatable as applies to how the ORM uses + these functions. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 29d7485961..d01aad4395 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -55,6 +55,7 @@ from .properties import MappedColumn from .util import _extract_mapped_subtype from .util import _is_mapped_annotation from .util import class_mapper +from .util import de_stringify_annotation from .. import event from .. import exc from .. import util @@ -64,7 +65,6 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index b65171c9d4..fd28830d9f 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -44,6 +44,7 @@ from .interfaces import _MapsColumns from .interfaces import MapperProperty from .interfaces import PropComparator from .util import _none_set +from .util import de_stringify_annotation from .. import event from .. import exc as sa_exc from .. import schema @@ -52,7 +53,6 @@ from .. import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter -from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 from ..util.typing import typing_get_args diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index a5f34f3de3..4c07bad235 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -41,6 +41,8 @@ from .interfaces import MapperProperty from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import RelationshipProperty +from .util import de_stringify_annotation +from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. import log @@ -52,8 +54,6 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types -from ..util.typing import de_stringify_annotation -from ..util.typing import de_stringify_union_elements from ..util.typing import is_fwd_ref from ..util.typing import is_optional_union from ..util.typing import is_pep593 diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index ad9ce2013d..7966f6cd9c 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -9,6 +9,7 @@ from __future__ import annotations import enum +import functools import re import types import typing @@ -46,6 +47,7 @@ from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper from .base import InspectionAttr as InspectionAttr from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped from .base import object_mapper as object_mapper from .base import object_state as object_state # noqa: F401 from .base import opt_manager_of_class @@ -79,10 +81,14 @@ from ..sql.elements import ColumnElement from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots -from ..util.typing import de_stringify_annotation -from ..util.typing import eval_name_only +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only from ..util.typing import is_origin_of_cls from ..util.typing import Literal +from ..util.typing import Protocol from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: @@ -113,6 +119,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Subquery from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -130,6 +137,58 @@ all_cascades = frozenset( ) +_de_stringify_partial = functools.partial( + functools.partial, locals_=util.immutabledict({"Mapped": Mapped}) +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: + ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: + ... + + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + class CascadeOptions(FrozenSet[str]): """Keeps track of the options sent to :paramref:`.relationship.cascade`""" @@ -2271,7 +2330,7 @@ def _extract_mapped_subtype( cls, raw_annotation, originating_module, - _cleanup_mapped_str_annotation, + str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: raise sa_exc.ArgumentError( diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 9e6df0d359..24d8dd2dc1 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -18,6 +18,7 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import Mapping from typing import NewType from typing import NoReturn from typing import Optional @@ -123,6 +124,8 @@ def de_stringify_annotation( cls: Type[Any], annotation: _AnnotationScanType, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, ) -> Type[Any]: @@ -150,7 +153,9 @@ def de_stringify_annotation( if str_cleanup_fn: annotation = str_cleanup_fn(annotation, originating_module) - annotation = eval_expression(annotation, originating_module) + annotation = eval_expression( + annotation, originating_module, locals_=locals_ + ) if ( include_generic @@ -162,6 +167,7 @@ def de_stringify_annotation( cls, elem, originating_module, + locals_, str_cleanup_fn=str_cleanup_fn, include_generic=include_generic, ) @@ -183,7 +189,12 @@ def _copy_generic_annotation_with( return annotation.__origin__[elements] # type: ignore -def eval_expression(expression: str, module_name: str) -> Any: +def eval_expression( + expression: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ except KeyError as ke: @@ -191,8 +202,9 @@ def eval_expression(expression: str, module_name: str) -> Any: f"Module {module_name} isn't present in sys.modules; can't " f"evaluate expression {expression}" ) from ke + try: - annotation = eval(expression, base_globals, None) + annotation = eval(expression, base_globals, locals_) except Exception as err: raise NameError( f"Could not de-stringify annotation {expression!r}" @@ -201,9 +213,14 @@ def eval_expression(expression: str, module_name: str) -> Any: return annotation -def eval_name_only(name: str, module_name: str) -> Any: +def eval_name_only( + name: str, + module_name: str, + *, + locals_: Optional[Mapping[str, Any]] = None, +) -> Any: if "." in name: - return eval_expression(name, module_name) + return eval_expression(name, module_name, locals_=locals_) try: base_globals: Dict[str, Any] = sys.modules[module_name].__dict__ @@ -237,12 +254,18 @@ def de_stringify_union_elements( cls: Type[Any], annotation: ArgsTypeProcotol, originating_module: str, + locals_: Mapping[str, Any], + *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, ) -> Type[Any]: return make_union_type( *[ de_stringify_annotation( - cls, anno, originating_module, str_cleanup_fn + cls, + anno, + originating_module, + {}, + str_cleanup_fn=str_cleanup_fn, ) for anno in annotation.__args__ ] diff --git a/test/orm/declarative/test_abs_import_only.py b/test/orm/declarative/test_abs_import_only.py new file mode 100644 index 0000000000..e700b4cc2d --- /dev/null +++ b/test/orm/declarative/test_abs_import_only.py @@ -0,0 +1,47 @@ +""" +this file tests that absolute imports can be used in declarative +mappings while guaranteeing that the Mapped name is not locally present + +""" + +from __future__ import annotations + +import sqlalchemy +from sqlalchemy import orm +import sqlalchemy.orm +import sqlalchemy.testing +import sqlalchemy.testing.fixtures + +try: + x = Mapped # type: ignore +except NameError: + pass +else: + raise Exception("Mapped name **must not be imported in this file**") + + +class MappedColumnTest( + sqlalchemy.testing.fixtures.TestBase, sqlalchemy.testing.AssertsCompiledSQL +): + __dialect__ = "default" + + def test_fully_qualified_mapped_name(self, decl_base): + """test #8853 *again*, as reported in #9335 this failed to be fixed""" + + class Foo(decl_base): + __tablename__ = "foo" + + id: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column( + primary_key=True + ) + + data: sqlalchemy.orm.Mapped[int] = sqlalchemy.orm.mapped_column() + + data2: sqlalchemy.orm.Mapped[int] + + data3: orm.Mapped[int] + + self.assert_compile( + sqlalchemy.select(Foo), + "SELECT foo.id, foo.data, foo.data2, foo.data3 FROM foo", + ) diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index b66d67a77f..1677cdbb9b 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -42,7 +42,12 @@ class M3: class MappedColumnTest(_MappedColumnTest): def test_fully_qualified_mapped_name(self, decl_base): - """test #8853, regression caused by #8759 ;)""" # noqa: E501 + """test #8853, regression caused by #8759 ;) + + + See same test in test_abs_import_only + + """ class Foo(decl_base): __tablename__ = "foo" -- 2.47.2