]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
evaluate types in terms of the class in which they appear
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 Oct 2022 13:51:51 +0000 (09:51 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 31 Oct 2022 16:56:21 +0000 (12:56 -0400)
Fixed issues within the declarative typing resolver (i.e. which resolves
``ForwardRef`` objects) where types that were declared for columns in one
particular source file would raise ``NameError`` when the ultimate mapped
class were in another source file.  The types are now resolved in terms
of the module for each class in which the types are used.

Fixes: #8742
Change-Id: I236f94484ea79d47392a6201e671eeb89c305fd8

doc/build/changelog/unreleased_20/8742.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/relationships.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations.py

diff --git a/doc/build/changelog/unreleased_20/8742.rst b/doc/build/changelog/unreleased_20/8742.rst
new file mode 100644 (file)
index 0000000..71e4583
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm declarative
+    :tickets: 8742
+
+    Fixed issues within the declarative typing resolver (i.e. which resolves
+    ``ForwardRef`` objects) where types that were declared for columns in one
+    particular source file would raise ``NameError`` when the ultimate mapped
+    class were in another source file.  The types are now resolved in terms
+    of the module for each class in which the types are used.
index 4e02e589b9cc2e7144e6e736d9f9f4ba487b6684..c233298b9fc5985ed43fa8b62121ec83e75f0062 100644 (file)
@@ -19,6 +19,7 @@ from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Mapping
+from typing import NamedTuple
 from typing import NoReturn
 from typing import Optional
 from typing import Sequence
@@ -70,6 +71,7 @@ from ..util.typing import typing_get_args
 if TYPE_CHECKING:
     from ._typing import _ClassDict
     from ._typing import _RegistryType
+    from .base import Mapped
     from .decl_api import declared_attr
     from .instrumentation import ClassManager
     from ..sql.elements import NamedColumn
@@ -397,6 +399,15 @@ class _ImperativeMapperConfig(_MapperConfig):
         self.inherits = inherits
 
 
+class _CollectedAnnotation(NamedTuple):
+    raw_annotation: _AnnotationScanType
+    mapped_container: Optional[Type[Mapped[Any]]]
+    extracted_mapped_annotation: Union[Type[Any], str]
+    is_dataclass: bool
+    attr_value: Any
+    originating_module: str
+
+
 class _ClassScanMapperConfig(_MapperConfig):
     __slots__ = (
         "registry",
@@ -420,7 +431,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
     registry: _RegistryType
     clsdict_view: _ClassDict
-    collected_annotations: Dict[str, Tuple[Any, Any, Any, bool, Any]]
+    collected_annotations: Dict[str, _CollectedAnnotation]
     collected_attributes: Dict[str, Any]
     local_table: Optional[FromClause]
     persist_selectable: Optional[FromClause]
@@ -740,6 +751,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                     local_attributes_for_class,
                     attribute_is_overridden,
                     fixed_table,
+                    base,
                 )
             else:
                 locally_collected_columns = {}
@@ -913,10 +925,11 @@ class _ClassScanMapperConfig(_MapperConfig):
                         self._collect_annotation(
                             name,
                             obj._collect_return_annotation(),
+                            base,
                             True,
                             obj,
                         )
-                    elif _is_mapped_annotation(annotation, cls):
+                    elif _is_mapped_annotation(annotation, cls, base):
                         # Mapped annotation without any object.
                         # product_column_copies should have handled this.
                         # if future support for other MapperProperty,
@@ -948,15 +961,17 @@ class _ClassScanMapperConfig(_MapperConfig):
                         obj = obj.fget()
 
                     collected_attributes[name] = obj
-                    self._collect_annotation(name, annotation, False, obj)
+                    self._collect_annotation(
+                        name, annotation, base, False, obj
+                    )
                 else:
                     generated_obj = self._collect_annotation(
-                        name, annotation, None, obj
+                        name, annotation, base, None, obj
                     )
                     if (
                         obj is None
                         and not fixed_table
-                        and _is_mapped_annotation(annotation, cls)
+                        and _is_mapped_annotation(annotation, cls, base)
                     ):
                         collected_attributes[name] = (
                             generated_obj
@@ -1005,6 +1020,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                     mapped_anno,
                     is_dc,
                     attr_value,
+                    originating_module,
                 ) in self.collected_annotations.items()
             )
         ]
@@ -1058,6 +1074,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         self,
         name: str,
         raw_annotation: _AnnotationScanType,
+        originating_class: Type[Any],
         expect_mapped: Optional[bool],
         attr_value: Any,
     ) -> Any:
@@ -1088,6 +1105,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         extracted = _extract_mapped_subtype(
             raw_annotation,
             self.cls,
+            originating_class.__module__,
             name,
             type(attr_value),
             required=False,
@@ -1109,12 +1127,13 @@ class _ClassScanMapperConfig(_MapperConfig):
                 if isinstance(elem, _IntrospectsAnnotations):
                     attr_value = elem.found_in_pep593_annotated()
 
-        self.collected_annotations[name] = (
+        self.collected_annotations[name] = _CollectedAnnotation(
             raw_annotation,
             mapped_container,
             extracted_mapped_annotation,
             is_dataclass,
             attr_value,
+            originating_class.__module__,
         )
         return attr_value
 
@@ -1135,6 +1154,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         ],
         attribute_is_overridden: Callable[[str, Any], bool],
         fixed_table: bool,
+        originating_class: Type[Any],
     ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]:
         cls = self.cls
         dict_ = self.clsdict_view
@@ -1146,9 +1166,11 @@ class _ClassScanMapperConfig(_MapperConfig):
             if (
                 not fixed_table
                 and obj is None
-                and _is_mapped_annotation(annotation, cls)
+                and _is_mapped_annotation(annotation, cls, originating_class)
             ):
-                obj = self._collect_annotation(name, annotation, True, obj)
+                obj = self._collect_annotation(
+                    name, annotation, originating_class, True, obj
+                )
                 if obj is None:
                     obj = MappedColumn()
 
@@ -1164,7 +1186,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # either (issue #8718)
                     continue
 
-                obj = self._collect_annotation(name, annotation, True, obj)
+                obj = self._collect_annotation(
+                    name, annotation, originating_class, True, obj
+                )
 
                 if name not in dict_ and not (
                     "__table__" in dict_
@@ -1282,8 +1306,9 @@ class _ClassScanMapperConfig(_MapperConfig):
                         extracted_mapped_annotation,
                         is_dataclass,
                         attr_value,
+                        originating_module,
                     ) = self.collected_annotations.get(
-                        k, (None, None, None, False, None)
+                        k, (None, None, None, False, None, None)
                     )
 
                     # issue #8692 - don't do any annotation interpretation if
@@ -1295,6 +1320,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                         value.declarative_scan(
                             self.registry,
                             cls,
+                            originating_module,
                             k,
                             mapped_container,
                             annotation,
index a15cd86f4328a071aee8a862e14d954a8d50c994..84d15360d2e877dfcbf19682db9cb053f4730521 100644 (file)
@@ -332,6 +332,7 @@ class CompositeProperty(
         self,
         registry: _RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         mapped_container: Optional[Type[Mapped[Any]]],
         annotation: Optional[_AnnotationScanType],
@@ -365,7 +366,7 @@ class CompositeProperty(
             self.composite_class = argument
 
         if is_dataclass(self.composite_class):
-            self._setup_for_dataclass(registry, cls, key)
+            self._setup_for_dataclass(registry, cls, originating_module, key)
         else:
             for attr in self.attrs:
                 if (
@@ -408,7 +409,11 @@ class CompositeProperty(
     @util.preload_module("sqlalchemy.orm.properties")
     @util.preload_module("sqlalchemy.orm.decl_base")
     def _setup_for_dataclass(
-        self, registry: _RegistryType, cls: Type[Any], key: str
+        self,
+        registry: _RegistryType,
+        cls: Type[Any],
+        originating_module: Optional[str],
+        key: str,
     ) -> None:
         MappedColumn = util.preloaded.orm_properties.MappedColumn
 
@@ -432,7 +437,12 @@ class CompositeProperty(
 
             if isinstance(attr, MappedColumn):
                 attr.declarative_scan_for_composite(
-                    registry, cls, key, param.name, param.annotation
+                    registry,
+                    cls,
+                    originating_module,
+                    key,
+                    param.name,
+                    param.annotation,
                 )
             elif isinstance(attr, schema.Column):
                 decl_base._undefer_column_name(param.name, attr)
index 1747bfd9b2a905afe70a4f8d224cfbfbdd62c547..e61e82126a8b6f6989af4d17b5357d8b72872cbd 100644 (file)
@@ -157,6 +157,7 @@ class _IntrospectsAnnotations:
         self,
         registry: RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         mapped_container: Optional[Type[Mapped[Any]]],
         annotation: Optional[_AnnotationScanType],
@@ -215,11 +216,12 @@ class _AttributeOptions(NamedTuple):
     def _get_arguments_for_make_dataclass(
         cls,
         key: str,
-        annotation: Type[Any],
+        annotation: _AnnotationScanType,
         mapped_container: Optional[Any],
         elem: _T,
     ) -> Union[
-        Tuple[str, Type[Any]], Tuple[str, Type[Any], dataclasses.Field[Any]]
+        Tuple[str, _AnnotationScanType],
+        Tuple[str, _AnnotationScanType, dataclasses.Field[Any]],
     ]:
         """given attribute key, annotation, and value from a class, return
         the argument tuple we would pass to dataclasses.make_dataclass()
index c67da3905f42bc368a7071782925535dc7b93e66..0cbd3f71355849926541e886f6dd562430c962fe 100644 (file)
@@ -197,6 +197,7 @@ class ColumnProperty(
         self,
         registry: _RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         mapped_container: Optional[Type[Mapped[Any]]],
         annotation: Optional[_AnnotationScanType],
@@ -637,6 +638,7 @@ class MappedColumn(
         self,
         registry: _RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         mapped_container: Optional[Type[Mapped[Any]]],
         annotation: Optional[_AnnotationScanType],
@@ -658,7 +660,7 @@ class MappedColumn(
                 return
 
         self._init_column_for_annotation(
-            cls, registry, extracted_mapped_annotation
+            cls, registry, extracted_mapped_annotation, originating_module
         )
 
     @util.preload_module("sqlalchemy.orm.decl_base")
@@ -666,27 +668,37 @@ class MappedColumn(
         self,
         registry: _RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         param_name: str,
         param_annotation: _AnnotationScanType,
     ) -> None:
         decl_base = util.preloaded.orm_decl_base
         decl_base._undefer_column_name(param_name, self.column)
-        self._init_column_for_annotation(cls, registry, param_annotation)
+        self._init_column_for_annotation(
+            cls, registry, param_annotation, originating_module
+        )
 
     def _init_column_for_annotation(
         self,
         cls: Type[Any],
         registry: _RegistryType,
         argument: _AnnotationScanType,
+        originating_module: Optional[str],
     ) -> None:
         sqltype = self.column.type
 
         if is_fwd_ref(argument):
-            argument = de_stringify_annotation(cls, argument)
+            assert originating_module is not None
+            argument = de_stringify_annotation(
+                cls, argument, originating_module
+            )
 
         if is_union(argument):
-            argument = de_stringify_union_elements(cls, argument)
+            assert originating_module is not None
+            argument = de_stringify_union_elements(
+                cls, argument, originating_module
+            )
 
         nullable = is_optional_union(argument)
 
index 276199da272661f76dd8793685707cb4f62c75ee..6d388a630b59b2226d9c47f6740f9950e9926283 100644 (file)
@@ -1712,6 +1712,7 @@ class RelationshipProperty(
         self,
         registry: _RegistryType,
         cls: Type[Any],
+        originating_module: Optional[str],
         key: str,
         mapped_container: Optional[Type[Mapped[Any]]],
         annotation: Optional[_AnnotationScanType],
index 3302feb70a807bfb1744bc784986f85f80306993..e64c936363fa510b23cfc8d4631e60af3b58be71 100644 (file)
@@ -2020,10 +2020,14 @@ def _getitem(iterable_query: Query[Any], item: Any) -> Any:
 
 
 def _is_mapped_annotation(
-    raw_annotation: _AnnotationScanType, cls: Type[Any]
+    raw_annotation: _AnnotationScanType,
+    cls: Type[Any],
+    originating_cls: Type[Any],
 ) -> bool:
     try:
-        annotated = de_stringify_annotation(cls, raw_annotation)
+        annotated = de_stringify_annotation(
+            cls, raw_annotation, originating_cls.__module__
+        )
     except NameError:
         return False
     else:
@@ -2065,6 +2069,7 @@ def _cleanup_mapped_str_annotation(annotation: str) -> str:
 def _extract_mapped_subtype(
     raw_annotation: Optional[_AnnotationScanType],
     cls: type,
+    originating_module: str,
     key: str,
     attr_cls: Type[Any],
     required: bool,
@@ -2091,7 +2096,10 @@ def _extract_mapped_subtype(
 
     try:
         annotated = de_stringify_annotation(
-            cls, raw_annotation, _cleanup_mapped_str_annotation
+            cls,
+            raw_annotation,
+            originating_module,
+            _cleanup_mapped_str_annotation,
         )
     except NameError as ne:
         if raiseerr and "Mapped[" in raw_annotation:  # type: ignore
index 1d93444476f0eaf9f9257d9cdcf71f89d9fa2b87..b5d918a74e4db5677c60f871d158596b07f44495 100644 (file)
@@ -84,6 +84,7 @@ _LiteralStar = Literal["*"]
 def de_stringify_annotation(
     cls: Type[Any],
     annotation: _AnnotationScanType,
+    originating_module: str,
     str_cleanup_fn: Optional[Callable[[str], str]] = None,
 ) -> Type[Any]:
     """Resolve annotations that may be string based into real objects.
@@ -109,14 +110,14 @@ def de_stringify_annotation(
     if isinstance(annotation, str):
         if str_cleanup_fn:
             annotation = str_cleanup_fn(annotation)
-
         base_globals: "Dict[str, Any]" = getattr(
-            sys.modules.get(cls.__module__, None), "__dict__", {}
+            sys.modules.get(originating_module, None), "__dict__", {}
         )
 
         try:
             annotation = eval(annotation, base_globals, None)
         except NameError as err:
+            # breakpoint()
             raise NameError(
                 f"Could not de-stringify annotation {annotation}"
             ) from err
@@ -126,11 +127,14 @@ def de_stringify_annotation(
 def de_stringify_union_elements(
     cls: Type[Any],
     annotation: _AnnotationScanType,
+    originating_module: str,
     str_cleanup_fn: Optional[Callable[[str], str]] = None,
 ) -> Type[Any]:
     return make_union_type(
         *[
-            de_stringify_annotation(cls, anno, str_cleanup_fn)
+            de_stringify_annotation(
+                cls, anno, originating_module, str_cleanup_fn
+            )
             for anno in annotation.__args__  # type: ignore
         ]
     )
index 24d666508a7bb082068e16bc91a611eda3b92490..03e41e7cfa913f781a410f102b0f0050d507611d 100644 (file)
@@ -6,6 +6,7 @@ from typing import Optional
 from typing import Set
 from typing import TypeVar
 from typing import Union
+import uuid
 
 from sqlalchemy import Column
 from sqlalchemy import exc
@@ -16,6 +17,7 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy import Uuid
 from sqlalchemy.orm import attribute_keyed_dict
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import DynamicMapped
@@ -116,6 +118,26 @@ class MappedColumnTest(_MappedColumnTest):
                 is_(optional_col.type, our_type)
                 is_true(optional_col.nullable)
 
+    def test_typ_not_in_cls_namespace(self, decl_base):
+        """test #8742.
+
+        This tests that when types are resolved, they use the ``__module__``
+        of they class they are used within, not the mapped class.
+
+        """
+
+        class Mixin:
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[uuid.UUID]
+
+        class MyClass(Mixin, decl_base):
+            # basically no type will be resolvable here
+            __module__ = "some.module"
+            __tablename__ = "mytable"
+
+        is_(MyClass.id.expression.type._type_affinity, Integer)
+        is_(MyClass.data.expression.type._type_affinity, Uuid)
+
 
 class MappedOneArg(KeyFuncDict[str, _R]):
     pass