]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
identify unresolvable Mapped types
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Nov 2022 15:58:49 +0000 (10:58 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 28 Nov 2022 17:01:48 +0000 (12:01 -0500)
Fixed issue where use of an unknown datatype within a :class:`.Mapped`
annotation for a column-based attribute would silently fail to map the
attribute, rather than reporting an exception; an informative exception
message is now raised.

tighten up iteration of names on mapped classes to more fully
exclude a large number of underscored names, so that we can avoid trying
to look at annotations for them or anything else.  centralize the
"list of names we care about" more fully within _cls_attr_resolver
and base it on underscore conventions we should usually ignore,
with the exception of the few underscore names we want to see.

Fixes: #8888
Change-Id: I3c0a1666579fe67b3c40cc74fa443b6f1de354ce

doc/build/changelog/unreleased_20/8888.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/util.py
test/orm/declarative/test_basic.py
test/orm/declarative/test_tm_future_annotations.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8888.rst b/doc/build/changelog/unreleased_20/8888.rst
new file mode 100644 (file)
index 0000000..61b2168
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8888
+
+    Fixed issue where use of an unknown datatype within a :class:`.Mapped`
+    annotation for a column-based attribute would silently fail to map the
+    attribute, rather than reporting an exception; an informative exception
+    message is now raised.
index 01766ad850c2b0a74cf1787f66a750d7149408a5..09397eb653fca393c3743cd1d6968f1c91153e37 100644 (file)
@@ -1553,6 +1553,10 @@ class registry:
 
 RegistryType = registry
 
+if not TYPE_CHECKING:
+    # allow for runtime type resolution of ``ClassVar[_RegistryType]``
+    _RegistryType = registry  # noqa
+
 
 def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]:
     """
index c23ea0311aa05d0b192600b0b7942ba33ae30272..21e3c3344dabe84bac4fcec6eb4814808bd5fb95 100644 (file)
@@ -635,36 +635,37 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         return attribute_is_overridden
 
-    _skip_attrs = frozenset(
-        [
-            "__module__",
-            "__annotations__",
-            "__doc__",
-            "__dict__",
-            "__weakref__",
-            "_sa_class_manager",
-            "_sa_apply_dc_transforms",
-            "__dict__",
-            "__weakref__",
-        ]
-    )
+    _include_dunders = {
+        "__table__",
+        "__mapper_args__",
+        "__tablename__",
+        "__table_args__",
+    }
+
+    _match_exclude_dunders = re.compile(r"^(?:_sa_|__)")
 
     def _cls_attr_resolver(
         self, cls: Type[Any]
     ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]:
-        """produce a function to iterate the "attributes" of a class,
-        adjusting for SQLAlchemy fields embedded in dataclass fields.
+        """produce a function to iterate the "attributes" of a class
+        which we want to consider for mapping, adjusting for SQLAlchemy fields
+        embedded in dataclass fields.
 
         """
         cls_annotations = util.get_annotations(cls)
 
         cls_vars = vars(cls)
 
-        skip = self._skip_attrs
+        _include_dunders = self._include_dunders
+        _match_exclude_dunders = self._match_exclude_dunders
 
-        names = util.merge_lists_w_ordering(
-            [n for n in cls_vars if n not in skip], list(cls_annotations)
-        )
+        names = [
+            n
+            for n in util.merge_lists_w_ordering(
+                list(cls_vars), list(cls_annotations)
+            )
+            if not _match_exclude_dunders.match(n) or n in _include_dunders
+        ]
 
         if self.allow_dataclass_fields:
             sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr(
@@ -719,6 +720,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         clsdict_view = self.clsdict_view
         collected_attributes = self.collected_attributes
         column_copies = self.column_copies
+        _include_dunders = self._include_dunders
         mapper_args_fn = None
         table_args = inherited_table_args = None
 
@@ -784,7 +786,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 annotation,
                 is_dataclass_field,
             ) in local_attributes_for_class():
-                if re.match(r"^__.+__$", name):
+                if name in _include_dunders:
                     if name == "__mapper_args__":
                         check_decl = _check_declared_props_nocascade(
                             obj, name, cls
@@ -825,7 +827,8 @@ class _ClassScanMapperConfig(_MapperConfig):
                             if base is not cls:
                                 inherited_table_args = True
                     else:
-                        # skip all other dunder names
+                        # skip all other dunder names, which at the moment
+                        # should only be __table__
                         continue
                 elif class_mapped:
                     if _is_declarative_props(obj):
@@ -965,14 +968,19 @@ class _ClassScanMapperConfig(_MapperConfig):
                         name, annotation, base, False, obj
                     )
                 else:
-                    generated_obj = self._collect_annotation(
+                    collected_annotation = self._collect_annotation(
                         name, annotation, base, None, obj
                     )
-                    if (
-                        obj is None
-                        and not fixed_table
-                        and _is_mapped_annotation(annotation, cls, base)
-                    ):
+                    is_mapped = (
+                        collected_annotation is not None
+                        and collected_annotation.mapped_container is not None
+                    )
+                    generated_obj = (
+                        collected_annotation.attr_value
+                        if collected_annotation is not None
+                        else obj
+                    )
+                    if obj is None and not fixed_table and is_mapped:
                         collected_attributes[name] = (
                             generated_obj
                             if generated_obj is not None
@@ -1077,13 +1085,13 @@ class _ClassScanMapperConfig(_MapperConfig):
         originating_class: Type[Any],
         expect_mapped: Optional[bool],
         attr_value: Any,
-    ) -> Any:
+    ) -> Optional[_CollectedAnnotation]:
 
         if name in self.collected_annotations:
-            return self.collected_annotations[name][4]
+            return self.collected_annotations[name]
 
         if raw_annotation is None:
-            return attr_value
+            return None
 
         is_dataclass = self.is_dataclass_prior_to_mapping
         allow_unmapped = self.allow_unmapped_annotations
@@ -1116,7 +1124,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         if extracted is None:
             # ClassVar can come out here
-            return attr_value
+            return None
 
         extracted_mapped_annotation, mapped_container = extracted
 
@@ -1136,7 +1144,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 if isinstance(elem, _IntrospectsAnnotations):
                     attr_value = elem.found_in_pep593_annotated()
 
-        self.collected_annotations[name] = _CollectedAnnotation(
+        self.collected_annotations[name] = ca = _CollectedAnnotation(
             raw_annotation,
             mapped_container,
             extracted_mapped_annotation,
@@ -1144,7 +1152,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             attr_value,
             originating_class.__module__,
         )
-        return attr_value
+        return ca
 
     def _warn_for_decl_attributes(
         self, cls: Type[Any], key: str, c: Any
@@ -1177,9 +1185,14 @@ class _ClassScanMapperConfig(_MapperConfig):
                 and obj is None
                 and _is_mapped_annotation(annotation, cls, originating_class)
             ):
-                obj = self._collect_annotation(
+                collected_annotation = self._collect_annotation(
                     name, annotation, originating_class, True, obj
                 )
+                obj = (
+                    collected_annotation.attr_value
+                    if collected_annotation is not None
+                    else obj
+                )
                 if obj is None:
                     obj = MappedColumn()
 
@@ -1195,9 +1208,14 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # either (issue #8718)
                     continue
 
-                obj = self._collect_annotation(
+                collected_annotation = self._collect_annotation(
                     name, annotation, originating_class, True, obj
                 )
+                obj = (
+                    collected_annotation.attr_value
+                    if collected_annotation is not None
+                    else obj
+                )
 
                 if name not in dict_ and not (
                     "__table__" in dict_
@@ -1233,6 +1251,8 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         our_stuff = self.properties
 
+        _include_dunders = self._include_dunders
+
         late_mapped = _get_immediate_cls_attr(
             cls, "_sa_decl_prepare_nocascade", strict=True
         )
@@ -1244,7 +1264,7 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         for k in list(collected_attributes):
 
-            if k in ("__table__", "__tablename__", "__mapper_args__"):
+            if k in _include_dunders:
                 continue
 
             value = collected_attributes[k]
@@ -1297,11 +1317,12 @@ class _ClassScanMapperConfig(_MapperConfig):
             # we expect to see the name 'metadata' in some valid cases;
             # however at this point we see it's assigned to something trying
             # to be mapped, so raise for that.
-            elif k == "metadata":
+            # TODO: should "registry" here be also?   might be too late
+            # to change that now (2.0 betas)
+            elif k in ("metadata",):
                 raise exc.InvalidRequestError(
-                    "Attribute name 'metadata' is reserved "
-                    "for the MetaData instance when using a "
-                    "declarative base class."
+                    f"Attribute name '{k}' is reserved when using the "
+                    "Declarative API."
                 )
             elif isinstance(value, Column):
                 _undefer_column_name(
@@ -1326,16 +1347,24 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # do declarative_scan so that the property can raise
                     # for required
                     if mapped_container is not None or annotation is None:
-                        value.declarative_scan(
-                            self.registry,
-                            cls,
-                            originating_module,
-                            k,
-                            mapped_container,
-                            annotation,
-                            extracted_mapped_annotation,
-                            is_dataclass,
-                        )
+                        try:
+                            value.declarative_scan(
+                                self.registry,
+                                cls,
+                                originating_module,
+                                k,
+                                mapped_container,
+                                annotation,
+                                extracted_mapped_annotation,
+                                is_dataclass,
+                            )
+                        except NameError as ne:
+                            raise exc.ArgumentError(
+                                f"Could not resolve all types within mapped "
+                                f'annotation: "{annotation}".  Ensure all '
+                                f"types are written correctly and are "
+                                f"imported within the module in use."
+                            ) from ne
                     else:
                         # assert that we were expecting annotations
                         # without Mapped[] were going to be passed.
index 6250cd104adde7a6e0b7defbb70585051ec7a1a8..58407a74d41fb2f986f7c6bc9630969ddb259492 100644 (file)
@@ -2033,6 +2033,12 @@ def _is_mapped_annotation(
             cls, raw_annotation, originating_cls.__module__
         )
     except NameError:
+        # in most cases, at least within our own tests, we can raise
+        # here, which is more accurate as it prevents us from returning
+        # false negatives.  However, in the real world, try to avoid getting
+        # involved with end-user annotations that have nothing to do with us.
+        # see issue #8888 where we bypass using this function in the case
+        # that we want to detect an unresolvable Mapped[] type.
         return False
     else:
         return is_origin_of_cls(annotated, _MappedAnnotationBase)
index 3dfc598272ea6fc0ea450c7c002273374ced6807..475b5e39bb80e3e653df85d136aeeedc664e1bb5 100644 (file)
@@ -46,6 +46,7 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assertions
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
@@ -1048,27 +1049,46 @@ class DeclarativeMultiBaseTest(
             configure_mappers,
         )
 
-    def test_reserved_identifiers(self):
-        def go1():
-            class User1(Base):
-                __tablename__ = "user1"
-                id = Column(Integer, primary_key=True)
-                metadata = Column(Integer)
+    # currently "registry" is allowed, "metadata" is not.
+    @testing.combinations(
+        ("metadata", True), ("registry", False), argnames="name, expect_raise"
+    )
+    @testing.variation("attrtype", ["column", "relationship"])
+    def test_reserved_identifiers(
+        self, decl_base, name, expect_raise, attrtype
+    ):
 
-        def go2():
-            class User2(Base):
-                __tablename__ = "user2"
+        if attrtype.column:
+            clsdict = {
+                "__tablename__": "user",
+                "id": Column(Integer, primary_key=True),
+                name: Column(Integer),
+            }
+        elif attrtype.relationship:
+            clsdict = {
+                "__tablename__": "user",
+                "id": Column(Integer, primary_key=True),
+                name: relationship("Address"),
+            }
+
+            class Address(decl_base):
+                __tablename__ = "address"
                 id = Column(Integer, primary_key=True)
-                metadata = relationship("Address")
+                user_id = Column(ForeignKey("user.id"))
 
-        for go in (go1, go2):
-            assert_raises_message(
+        else:
+            assert False
+
+        if expect_raise:
+            with expect_raises_message(
                 exc.InvalidRequestError,
-                "Attribute name 'metadata' is reserved "
-                "for the MetaData instance when using a "
-                "declarative base class.",
-                go,
-            )
+                f"Attribute name '{name}' is reserved "
+                "when using the Declarative API.",
+            ):
+                type("User", (decl_base,), clsdict)
+        else:
+            User = type("User", (decl_base,), clsdict)
+            assert getattr(User, name).property
 
     def test_recompile_on_othermapper(self):
         """declarative version of the same test in mappers.py"""
index 1e8913368ccb46bab472490d7b70983c0cc9f61f..b66d67a77ffc0aaaf3efa5ee09e774175d3035b2 100644 (file)
@@ -160,6 +160,22 @@ class MappedColumnTest(_MappedColumnTest):
         is_(MyClass.id.expression.type._type_affinity, Integer)
         is_(MyClass.data.expression.type._type_affinity, Uuid)
 
+    def test_dont_ignore_unresolvable(self, decl_base):
+        """test #8888"""
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r"Could not resolve all types within mapped annotation: "
+            r"\"Mapped\[fake\]\".  Ensure all types are written correctly and "
+            r"are imported within the module in use.",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[fake]  # noqa
+
 
 class MappedOneArg(KeyFuncDict[str, _R]):
     pass
index 1a4ac6de33d676d7d0d28fb490261640faeac5a5..7358f385db58a420e6b813f5f190ee1d8efac1aa 100644 (file)
@@ -1118,6 +1118,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         assert isinstance(MyClass.__table__.c.data.type, sqltype)
 
+    def test_dont_ignore_unresolvable(self, decl_base):
+        """test #8888"""
+
+        with expect_raises_message(
+            sa_exc.ArgumentError,
+            r"Could not resolve all types within mapped annotation: "
+            r"\".*Mapped\[.*fake.*\]\".  Ensure all types are written "
+            r"correctly and are imported within the module in use.",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped["fake"]  # noqa
+
 
 class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
index 05ceee3f8849171d246fd41914afe2a613a7ac73..ba099412f3ef204fa656073de458eaa6a805ed66 100644 (file)
@@ -1109,6 +1109,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
         assert isinstance(MyClass.__table__.c.data.type, sqltype)
 
+    def test_dont_ignore_unresolvable(self, decl_base):
+        """test #8888"""
+
+        with expect_raises_message(
+            sa_exc.ArgumentError,
+            r"Could not resolve all types within mapped annotation: "
+            r"\".*Mapped\[.*fake.*\]\".  Ensure all types are written "
+            r"correctly and are imported within the module in use.",
+        ):
+
+            class A(decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped["fake"]  # noqa
+
 
 class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"