From: Mike Bayer Date: Mon, 31 Oct 2022 13:51:51 +0000 (-0400) Subject: evaluate types in terms of the class in which they appear X-Git-Tag: rel_2_0_0b3~15 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=66f3533de86506327c753c1ea80b121692535745;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git evaluate types in terms of the class in which they appear Fixed issues within the declarative typing resolver (i.e. which resolves ``ForwardRef`` objects) where types that were declared for columns in one particular source file would raise ``NameError`` when the ultimate mapped class were in another source file. The types are now resolved in terms of the module for each class in which the types are used. Fixes: #8742 Change-Id: I236f94484ea79d47392a6201e671eeb89c305fd8 --- diff --git a/doc/build/changelog/unreleased_20/8742.rst b/doc/build/changelog/unreleased_20/8742.rst new file mode 100644 index 0000000000..71e4583345 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8742.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm declarative + :tickets: 8742 + + Fixed issues within the declarative typing resolver (i.e. which resolves + ``ForwardRef`` objects) where types that were declared for columns in one + particular source file would raise ``NameError`` when the ultimate mapped + class were in another source file. The types are now resolved in terms + of the module for each class in which the types are used. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 4e02e589b9..c233298b9f 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -19,6 +19,7 @@ from typing import Dict from typing import Iterable from typing import List from typing import Mapping +from typing import NamedTuple from typing import NoReturn from typing import Optional from typing import Sequence @@ -70,6 +71,7 @@ from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ClassDict from ._typing import _RegistryType + from .base import Mapped from .decl_api import declared_attr from .instrumentation import ClassManager from ..sql.elements import NamedColumn @@ -397,6 +399,15 @@ class _ImperativeMapperConfig(_MapperConfig): self.inherits = inherits +class _CollectedAnnotation(NamedTuple): + raw_annotation: _AnnotationScanType + mapped_container: Optional[Type[Mapped[Any]]] + extracted_mapped_annotation: Union[Type[Any], str] + is_dataclass: bool + attr_value: Any + originating_module: str + + class _ClassScanMapperConfig(_MapperConfig): __slots__ = ( "registry", @@ -420,7 +431,7 @@ class _ClassScanMapperConfig(_MapperConfig): registry: _RegistryType clsdict_view: _ClassDict - collected_annotations: Dict[str, Tuple[Any, Any, Any, bool, Any]] + collected_annotations: Dict[str, _CollectedAnnotation] collected_attributes: Dict[str, Any] local_table: Optional[FromClause] persist_selectable: Optional[FromClause] @@ -740,6 +751,7 @@ class _ClassScanMapperConfig(_MapperConfig): local_attributes_for_class, attribute_is_overridden, fixed_table, + base, ) else: locally_collected_columns = {} @@ -913,10 +925,11 @@ class _ClassScanMapperConfig(_MapperConfig): self._collect_annotation( name, obj._collect_return_annotation(), + base, True, obj, ) - elif _is_mapped_annotation(annotation, cls): + elif _is_mapped_annotation(annotation, cls, base): # Mapped annotation without any object. # product_column_copies should have handled this. # if future support for other MapperProperty, @@ -948,15 +961,17 @@ class _ClassScanMapperConfig(_MapperConfig): obj = obj.fget() collected_attributes[name] = obj - self._collect_annotation(name, annotation, False, obj) + self._collect_annotation( + name, annotation, base, False, obj + ) else: generated_obj = self._collect_annotation( - name, annotation, None, obj + name, annotation, base, None, obj ) if ( obj is None and not fixed_table - and _is_mapped_annotation(annotation, cls) + and _is_mapped_annotation(annotation, cls, base) ): collected_attributes[name] = ( generated_obj @@ -1005,6 +1020,7 @@ class _ClassScanMapperConfig(_MapperConfig): mapped_anno, is_dc, attr_value, + originating_module, ) in self.collected_annotations.items() ) ] @@ -1058,6 +1074,7 @@ class _ClassScanMapperConfig(_MapperConfig): self, name: str, raw_annotation: _AnnotationScanType, + originating_class: Type[Any], expect_mapped: Optional[bool], attr_value: Any, ) -> Any: @@ -1088,6 +1105,7 @@ class _ClassScanMapperConfig(_MapperConfig): extracted = _extract_mapped_subtype( raw_annotation, self.cls, + originating_class.__module__, name, type(attr_value), required=False, @@ -1109,12 +1127,13 @@ class _ClassScanMapperConfig(_MapperConfig): if isinstance(elem, _IntrospectsAnnotations): attr_value = elem.found_in_pep593_annotated() - self.collected_annotations[name] = ( + self.collected_annotations[name] = _CollectedAnnotation( raw_annotation, mapped_container, extracted_mapped_annotation, is_dataclass, attr_value, + originating_class.__module__, ) return attr_value @@ -1135,6 +1154,7 @@ class _ClassScanMapperConfig(_MapperConfig): ], attribute_is_overridden: Callable[[str, Any], bool], fixed_table: bool, + originating_class: Type[Any], ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]: cls = self.cls dict_ = self.clsdict_view @@ -1146,9 +1166,11 @@ class _ClassScanMapperConfig(_MapperConfig): if ( not fixed_table and obj is None - and _is_mapped_annotation(annotation, cls) + and _is_mapped_annotation(annotation, cls, originating_class) ): - obj = self._collect_annotation(name, annotation, True, obj) + obj = self._collect_annotation( + name, annotation, originating_class, True, obj + ) if obj is None: obj = MappedColumn() @@ -1164,7 +1186,9 @@ class _ClassScanMapperConfig(_MapperConfig): # either (issue #8718) continue - obj = self._collect_annotation(name, annotation, True, obj) + obj = self._collect_annotation( + name, annotation, originating_class, True, obj + ) if name not in dict_ and not ( "__table__" in dict_ @@ -1282,8 +1306,9 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, is_dataclass, attr_value, + originating_module, ) = self.collected_annotations.get( - k, (None, None, None, False, None) + k, (None, None, None, False, None, None) ) # issue #8692 - don't do any annotation interpretation if @@ -1295,6 +1320,7 @@ class _ClassScanMapperConfig(_MapperConfig): value.declarative_scan( self.registry, cls, + originating_module, k, mapped_container, annotation, diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index a15cd86f43..84d15360d2 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -332,6 +332,7 @@ class CompositeProperty( self, registry: _RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, mapped_container: Optional[Type[Mapped[Any]]], annotation: Optional[_AnnotationScanType], @@ -365,7 +366,7 @@ class CompositeProperty( self.composite_class = argument if is_dataclass(self.composite_class): - self._setup_for_dataclass(registry, cls, key) + self._setup_for_dataclass(registry, cls, originating_module, key) else: for attr in self.attrs: if ( @@ -408,7 +409,11 @@ class CompositeProperty( @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") def _setup_for_dataclass( - self, registry: _RegistryType, cls: Type[Any], key: str + self, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, ) -> None: MappedColumn = util.preloaded.orm_properties.MappedColumn @@ -432,7 +437,12 @@ class CompositeProperty( if isinstance(attr, MappedColumn): attr.declarative_scan_for_composite( - registry, cls, key, param.name, param.annotation + registry, + cls, + originating_module, + key, + param.name, + param.annotation, ) elif isinstance(attr, schema.Column): decl_base._undefer_column_name(param.name, attr) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 1747bfd9b2..e61e82126a 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -157,6 +157,7 @@ class _IntrospectsAnnotations: self, registry: RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, mapped_container: Optional[Type[Mapped[Any]]], annotation: Optional[_AnnotationScanType], @@ -215,11 +216,12 @@ class _AttributeOptions(NamedTuple): def _get_arguments_for_make_dataclass( cls, key: str, - annotation: Type[Any], + annotation: _AnnotationScanType, mapped_container: Optional[Any], elem: _T, ) -> Union[ - Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]] + Tuple[str, _AnnotationScanType], + Tuple[str, _AnnotationScanType, dataclasses.Field[Any]], ]: """given attribute key, annotation, and value from a class, return the argument tuple we would pass to dataclasses.make_dataclass() diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c67da3905f..0cbd3f7135 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -197,6 +197,7 @@ class ColumnProperty( self, registry: _RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, mapped_container: Optional[Type[Mapped[Any]]], annotation: Optional[_AnnotationScanType], @@ -637,6 +638,7 @@ class MappedColumn( self, registry: _RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, mapped_container: Optional[Type[Mapped[Any]]], annotation: Optional[_AnnotationScanType], @@ -658,7 +660,7 @@ class MappedColumn( return self._init_column_for_annotation( - cls, registry, extracted_mapped_annotation + cls, registry, extracted_mapped_annotation, originating_module ) @util.preload_module("sqlalchemy.orm.decl_base") @@ -666,27 +668,37 @@ class MappedColumn( self, registry: _RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, param_name: str, param_annotation: _AnnotationScanType, ) -> None: decl_base = util.preloaded.orm_decl_base decl_base._undefer_column_name(param_name, self.column) - self._init_column_for_annotation(cls, registry, param_annotation) + self._init_column_for_annotation( + cls, registry, param_annotation, originating_module + ) def _init_column_for_annotation( self, cls: Type[Any], registry: _RegistryType, argument: _AnnotationScanType, + originating_module: Optional[str], ) -> None: sqltype = self.column.type if is_fwd_ref(argument): - argument = de_stringify_annotation(cls, argument) + assert originating_module is not None + argument = de_stringify_annotation( + cls, argument, originating_module + ) if is_union(argument): - argument = de_stringify_union_elements(cls, argument) + assert originating_module is not None + argument = de_stringify_union_elements( + cls, argument, originating_module + ) nullable = is_optional_union(argument) diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 276199da27..6d388a630b 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1712,6 +1712,7 @@ class RelationshipProperty( self, registry: _RegistryType, cls: Type[Any], + originating_module: Optional[str], key: str, mapped_container: Optional[Type[Mapped[Any]]], annotation: Optional[_AnnotationScanType], diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 3302feb70a..e64c936363 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -2020,10 +2020,14 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any: def _is_mapped_annotation( - raw_annotation: _AnnotationScanType, cls: Type[Any] + raw_annotation: _AnnotationScanType, + cls: Type[Any], + originating_cls: Type[Any], ) -> bool: try: - annotated = de_stringify_annotation(cls, raw_annotation) + annotated = de_stringify_annotation( + cls, raw_annotation, originating_cls.__module__ + ) except NameError: return False else: @@ -2065,6 +2069,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str: def _extract_mapped_subtype( raw_annotation: Optional[_AnnotationScanType], cls: type, + originating_module: str, key: str, attr_cls: Type[Any], required: bool, @@ -2091,7 +2096,10 @@ def _extract_mapped_subtype( try: annotated = de_stringify_annotation( - cls, raw_annotation, _cleanup_mapped_str_annotation + cls, + raw_annotation, + originating_module, + _cleanup_mapped_str_annotation, ) except NameError as ne: if raiseerr and "Mapped[" in raw_annotation: # type: ignore diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 1d93444476..b5d918a74e 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -84,6 +84,7 @@ _LiteralStar = Literal["*"] def de_stringify_annotation( cls: Type[Any], annotation: _AnnotationScanType, + originating_module: str, str_cleanup_fn: Optional[Callable[[str], str]] = None, ) -> Type[Any]: """Resolve annotations that may be string based into real objects. @@ -109,14 +110,14 @@ def de_stringify_annotation( if isinstance(annotation, str): if str_cleanup_fn: annotation = str_cleanup_fn(annotation) - base_globals: "Dict[str, Any]" = getattr( - sys.modules.get(cls.__module__, None), "__dict__", {} + sys.modules.get(originating_module, None), "__dict__", {} ) try: annotation = eval(annotation, base_globals, None) except NameError as err: + # breakpoint() raise NameError( f"Could not de-stringify annotation {annotation}" ) from err @@ -126,11 +127,14 @@ def de_stringify_annotation( def de_stringify_union_elements( cls: Type[Any], annotation: _AnnotationScanType, + originating_module: str, str_cleanup_fn: Optional[Callable[[str], str]] = None, ) -> Type[Any]: return make_union_type( *[ - de_stringify_annotation(cls, anno, str_cleanup_fn) + de_stringify_annotation( + cls, anno, originating_module, str_cleanup_fn + ) for anno in annotation.__args__ # type: ignore ] ) diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index 24d666508a..03e41e7cfa 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -6,6 +6,7 @@ from typing import Optional from typing import Set from typing import TypeVar from typing import Union +import uuid from sqlalchemy import Column from sqlalchemy import exc @@ -16,6 +17,7 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy import Uuid from sqlalchemy.orm import attribute_keyed_dict from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DynamicMapped @@ -116,6 +118,26 @@ class MappedColumnTest(_MappedColumnTest): is_(optional_col.type, our_type) is_true(optional_col.nullable) + def test_typ_not_in_cls_namespace(self, decl_base): + """test #8742. + + This tests that when types are resolved, they use the ``__module__`` + of they class they are used within, not the mapped class. + + """ + + class Mixin: + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[uuid.UUID] + + class MyClass(Mixin, decl_base): + # basically no type will be resolvable here + __module__ = "some.module" + __tablename__ = "mytable" + + is_(MyClass.id.expression.type._type_affinity, Integer) + is_(MyClass.data.expression.type._type_affinity, Uuid) + class MappedOneArg(KeyFuncDict[str, _R]): pass