From 0c50f8dfdeb8adf997cbc8aa03443e8e47761cb3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 28 Nov 2022 10:58:49 -0500 Subject: [PATCH] identify unresolvable Mapped types Fixed issue where use of an unknown datatype within a :class:`.Mapped` annotation for a column-based attribute would silently fail to map the attribute, rather than reporting an exception; an informative exception message is now raised. tighten up iteration of names on mapped classes to more fully exclude a large number of underscored names, so that we can avoid trying to look at annotations for them or anything else. centralize the "list of names we care about" more fully within _cls_attr_resolver and base it on underscore conventions we should usually ignore, with the exception of the few underscore names we want to see. Fixes: #8888 Change-Id: I3c0a1666579fe67b3c40cc74fa443b6f1de354ce --- doc/build/changelog/unreleased_20/8888.rst | 8 ++ lib/sqlalchemy/orm/decl_api.py | 4 + lib/sqlalchemy/orm/decl_base.py | 129 +++++++++++------- lib/sqlalchemy/orm/util.py | 6 + test/orm/declarative/test_basic.py | 54 +++++--- .../declarative/test_tm_future_annotations.py | 16 +++ .../test_tm_future_annotations_sync.py | 16 +++ test/orm/declarative/test_typed_mapping.py | 16 +++ 8 files changed, 182 insertions(+), 67 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/8888.rst diff --git a/doc/build/changelog/unreleased_20/8888.rst b/doc/build/changelog/unreleased_20/8888.rst new file mode 100644 index 0000000000..61b216804e --- /dev/null +++ b/doc/build/changelog/unreleased_20/8888.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 8888 + + Fixed issue where use of an unknown datatype within a :class:`.Mapped` + annotation for a column-based attribute would silently fail to map the + attribute, rather than reporting an exception; an informative exception + message is now raised. diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 01766ad850..09397eb653 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -1553,6 +1553,10 @@ class registry: RegistryType = registry +if not TYPE_CHECKING: + # allow for runtime type resolution of ``ClassVar[_RegistryType]`` + _RegistryType = registry # noqa + def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: """ diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index c23ea0311a..21e3c3344d 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -635,36 +635,37 @@ class _ClassScanMapperConfig(_MapperConfig): return attribute_is_overridden - _skip_attrs = frozenset( - [ - "__module__", - "__annotations__", - "__doc__", - "__dict__", - "__weakref__", - "_sa_class_manager", - "_sa_apply_dc_transforms", - "__dict__", - "__weakref__", - ] - ) + _include_dunders = { + "__table__", + "__mapper_args__", + "__tablename__", + "__table_args__", + } + + _match_exclude_dunders = re.compile(r"^(?:_sa_|__)") def _cls_attr_resolver( self, cls: Type[Any] ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: - """produce a function to iterate the "attributes" of a class, - adjusting for SQLAlchemy fields embedded in dataclass fields. + """produce a function to iterate the "attributes" of a class + which we want to consider for mapping, adjusting for SQLAlchemy fields + embedded in dataclass fields. """ cls_annotations = util.get_annotations(cls) cls_vars = vars(cls) - skip = self._skip_attrs + _include_dunders = self._include_dunders + _match_exclude_dunders = self._match_exclude_dunders - names = util.merge_lists_w_ordering( - [n for n in cls_vars if n not in skip], list(cls_annotations) - ) + names = [ + n + for n in util.merge_lists_w_ordering( + list(cls_vars), list(cls_annotations) + ) + if not _match_exclude_dunders.match(n) or n in _include_dunders + ] if self.allow_dataclass_fields: sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( @@ -719,6 +720,7 @@ class _ClassScanMapperConfig(_MapperConfig): clsdict_view = self.clsdict_view collected_attributes = self.collected_attributes column_copies = self.column_copies + _include_dunders = self._include_dunders mapper_args_fn = None table_args = inherited_table_args = None @@ -784,7 +786,7 @@ class _ClassScanMapperConfig(_MapperConfig): annotation, is_dataclass_field, ) in local_attributes_for_class(): - if re.match(r"^__.+__$", name): + if name in _include_dunders: if name == "__mapper_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -825,7 +827,8 @@ class _ClassScanMapperConfig(_MapperConfig): if base is not cls: inherited_table_args = True else: - # skip all other dunder names + # skip all other dunder names, which at the moment + # should only be __table__ continue elif class_mapped: if _is_declarative_props(obj): @@ -965,14 +968,19 @@ class _ClassScanMapperConfig(_MapperConfig): name, annotation, base, False, obj ) else: - generated_obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, base, None, obj ) - if ( - obj is None - and not fixed_table - and _is_mapped_annotation(annotation, cls, base) - ): + is_mapped = ( + collected_annotation is not None + and collected_annotation.mapped_container is not None + ) + generated_obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None and not fixed_table and is_mapped: collected_attributes[name] = ( generated_obj if generated_obj is not None @@ -1077,13 +1085,13 @@ class _ClassScanMapperConfig(_MapperConfig): originating_class: Type[Any], expect_mapped: Optional[bool], attr_value: Any, - ) -> Any: + ) -> Optional[_CollectedAnnotation]: if name in self.collected_annotations: - return self.collected_annotations[name][4] + return self.collected_annotations[name] if raw_annotation is None: - return attr_value + return None is_dataclass = self.is_dataclass_prior_to_mapping allow_unmapped = self.allow_unmapped_annotations @@ -1116,7 +1124,7 @@ class _ClassScanMapperConfig(_MapperConfig): if extracted is None: # ClassVar can come out here - return attr_value + return None extracted_mapped_annotation, mapped_container = extracted @@ -1136,7 +1144,7 @@ class _ClassScanMapperConfig(_MapperConfig): if isinstance(elem, _IntrospectsAnnotations): attr_value = elem.found_in_pep593_annotated() - self.collected_annotations[name] = _CollectedAnnotation( + self.collected_annotations[name] = ca = _CollectedAnnotation( raw_annotation, mapped_container, extracted_mapped_annotation, @@ -1144,7 +1152,7 @@ class _ClassScanMapperConfig(_MapperConfig): attr_value, originating_class.__module__, ) - return attr_value + return ca def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any @@ -1177,9 +1185,14 @@ class _ClassScanMapperConfig(_MapperConfig): and obj is None and _is_mapped_annotation(annotation, cls, originating_class) ): - obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, originating_class, True, obj ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) if obj is None: obj = MappedColumn() @@ -1195,9 +1208,14 @@ class _ClassScanMapperConfig(_MapperConfig): # either (issue #8718) continue - obj = self._collect_annotation( + collected_annotation = self._collect_annotation( name, annotation, originating_class, True, obj ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) if name not in dict_ and not ( "__table__" in dict_ @@ -1233,6 +1251,8 @@ class _ClassScanMapperConfig(_MapperConfig): our_stuff = self.properties + _include_dunders = self._include_dunders + late_mapped = _get_immediate_cls_attr( cls, "_sa_decl_prepare_nocascade", strict=True ) @@ -1244,7 +1264,7 @@ class _ClassScanMapperConfig(_MapperConfig): for k in list(collected_attributes): - if k in ("__table__", "__tablename__", "__mapper_args__"): + if k in _include_dunders: continue value = collected_attributes[k] @@ -1297,11 +1317,12 @@ class _ClassScanMapperConfig(_MapperConfig): # we expect to see the name 'metadata' in some valid cases; # however at this point we see it's assigned to something trying # to be mapped, so raise for that. - elif k == "metadata": + # TODO: should "registry" here be also? might be too late + # to change that now (2.0 betas) + elif k in ("metadata",): raise exc.InvalidRequestError( - "Attribute name 'metadata' is reserved " - "for the MetaData instance when using a " - "declarative base class." + f"Attribute name '{k}' is reserved when using the " + "Declarative API." ) elif isinstance(value, Column): _undefer_column_name( @@ -1326,16 +1347,24 @@ class _ClassScanMapperConfig(_MapperConfig): # do declarative_scan so that the property can raise # for required if mapped_container is not None or annotation is None: - value.declarative_scan( - self.registry, - cls, - originating_module, - k, - mapped_container, - annotation, - extracted_mapped_annotation, - is_dataclass, - ) + try: + value.declarative_scan( + self.registry, + cls, + originating_module, + k, + mapped_container, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + except NameError as ne: + raise exc.ArgumentError( + f"Could not resolve all types within mapped " + f'annotation: "{annotation}". Ensure all ' + f"types are written correctly and are " + f"imported within the module in use." + ) from ne else: # assert that we were expecting annotations # without Mapped[] were going to be passed. diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 6250cd104a..58407a74d4 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -2033,6 +2033,12 @@ def _is_mapped_annotation( cls, raw_annotation, originating_cls.__module__ ) except NameError: + # in most cases, at least within our own tests, we can raise + # here, which is more accurate as it prevents us from returning + # false negatives. However, in the real world, try to avoid getting + # involved with end-user annotations that have nothing to do with us. + # see issue #8888 where we bypass using this function in the case + # that we want to detect an unresolvable Mapped[] type. return False else: return is_origin_of_cls(annotated, _MappedAnnotationBase) diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 3dfc598272..475b5e39bb 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -46,6 +46,7 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assertions from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -1048,27 +1049,46 @@ class DeclarativeMultiBaseTest( configure_mappers, ) - def test_reserved_identifiers(self): - def go1(): - class User1(Base): - __tablename__ = "user1" - id = Column(Integer, primary_key=True) - metadata = Column(Integer) + # currently "registry" is allowed, "metadata" is not. + @testing.combinations( + ("metadata", True), ("registry", False), argnames="name, expect_raise" + ) + @testing.variation("attrtype", ["column", "relationship"]) + def test_reserved_identifiers( + self, decl_base, name, expect_raise, attrtype + ): - def go2(): - class User2(Base): - __tablename__ = "user2" + if attrtype.column: + clsdict = { + "__tablename__": "user", + "id": Column(Integer, primary_key=True), + name: Column(Integer), + } + elif attrtype.relationship: + clsdict = { + "__tablename__": "user", + "id": Column(Integer, primary_key=True), + name: relationship("Address"), + } + + class Address(decl_base): + __tablename__ = "address" id = Column(Integer, primary_key=True) - metadata = relationship("Address") + user_id = Column(ForeignKey("user.id")) - for go in (go1, go2): - assert_raises_message( + else: + assert False + + if expect_raise: + with expect_raises_message( exc.InvalidRequestError, - "Attribute name 'metadata' is reserved " - "for the MetaData instance when using a " - "declarative base class.", - go, - ) + f"Attribute name '{name}' is reserved " + "when using the Declarative API.", + ): + type("User", (decl_base,), clsdict) + else: + User = type("User", (decl_base,), clsdict) + assert getattr(User, name).property def test_recompile_on_othermapper(self): """declarative version of the same test in mappers.py""" diff --git a/test/orm/declarative/test_tm_future_annotations.py b/test/orm/declarative/test_tm_future_annotations.py index 1e8913368c..b66d67a77f 100644 --- a/test/orm/declarative/test_tm_future_annotations.py +++ b/test/orm/declarative/test_tm_future_annotations.py @@ -160,6 +160,22 @@ class MappedColumnTest(_MappedColumnTest): is_(MyClass.id.expression.type._type_affinity, Integer) is_(MyClass.data.expression.type._type_affinity, Uuid) + def test_dont_ignore_unresolvable(self, decl_base): + """test #8888""" + + with expect_raises_message( + exc.ArgumentError, + r"Could not resolve all types within mapped annotation: " + r"\"Mapped\[fake\]\". Ensure all types are written correctly and " + r"are imported within the module in use.", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[fake] # noqa + class MappedOneArg(KeyFuncDict[str, _R]): pass diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index 1a4ac6de33..7358f385db 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -1118,6 +1118,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert isinstance(MyClass.__table__.c.data.type, sqltype) + def test_dont_ignore_unresolvable(self, decl_base): + """test #8888""" + + with expect_raises_message( + sa_exc.ArgumentError, + r"Could not resolve all types within mapped annotation: " + r"\".*Mapped\[.*fake.*\]\". Ensure all types are written " + r"correctly and are imported within the module in use.", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped["fake"] # noqa + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 05ceee3f88..ba099412f3 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1109,6 +1109,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): assert isinstance(MyClass.__table__.c.data.type, sqltype) + def test_dont_ignore_unresolvable(self, decl_base): + """test #8888""" + + with expect_raises_message( + sa_exc.ArgumentError, + r"Could not resolve all types within mapped annotation: " + r"\".*Mapped\[.*fake.*\]\". Ensure all types are written " + r"correctly and are imported within the module in use.", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped["fake"] # noqa + class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" -- 2.47.2