From 5777e268d092889d1764ed13cb62d007cdbeb0b5 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Fri, 15 Jul 2022 21:55:52 +0200 Subject: [PATCH] Ensure that a daclarative base is not used directly Fixes: #8248 Change-Id: I4f4c690dd8659eaf74e9c757d681e9edc7d33eee --- lib/sqlalchemy/orm/decl_api.py | 34 ++++++++++++++------ test/orm/declarative/test_basic.py | 51 ++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 7249698c00..500f2786e3 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -522,7 +522,7 @@ def declarative_mixin(cls: Type[_T]) -> Type[_T]: def _setup_declarative_base(cls: Type[Any]) -> None: if "metadata" in cls.__dict__: - metadata = cls.metadata # type: ignore + metadata = cls.__dict__["metadata"] else: metadata = None @@ -688,11 +688,27 @@ class DeclarativeBase( def __init_subclass__(cls) -> None: if DeclarativeBase in cls.__bases__: + _check_not_declarative(cls, DeclarativeBase) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) +def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: + cls_dict = cls.__dict__ + if ( + "__table__" in cls_dict + and not ( + callable(cls_dict["__table__"]) + or hasattr(cls_dict["__table__"], "__get__") + ) + ) or isinstance(cls_dict.get("__tablename__", None), str): + raise exc.InvalidRequestError( + f"Cannot use {base.__name__!r} directly as a declarative base " + "class. Create a Base by creating a subclass of it." + ) + + class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass to intercept new attributes. @@ -705,22 +721,20 @@ class DeclarativeBaseNoMeta(inspection.Inspectable[Mapper[Any]]): """ - if typing.TYPE_CHECKING: - registry: ClassVar[_RegistryType] - _sa_registry: ClassVar[_RegistryType] - metadata: ClassVar[MetaData] - - __name__: ClassVar[str] - __mapper__: ClassVar[Mapper[Any]] - __table__: ClassVar[Optional[FromClause]] + registry: ClassVar[_RegistryType] + _sa_registry: ClassVar[_RegistryType] + metadata: ClassVar[MetaData] + __mapper__: ClassVar[Mapper[Any]] + __table__: Optional[FromClause] - __tablename__: ClassVar[Any] + if typing.TYPE_CHECKING: def __init__(self, **kw: Any): ... def __init_subclass__(cls) -> None: if DeclarativeBaseNoMeta in cls.__bases__: + _check_not_declarative(cls, DeclarativeBaseNoMeta) _setup_declarative_base(cls) else: cls._sa_registry.map_declaratively(cls) diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 4990056c3e..e93286e40d 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -554,6 +554,57 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): class MyClass(DeclarativeBase): registry = {"foo": "bar"} + def test_declarative_base_registry_and_type_map(self): + with assertions.expect_raises_message( + exc.InvalidRequestError, + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps", + ): + + class MyClass(DeclarativeBase): + registry = registry() + type_annotation_map = {int: Integer} + + @testing.combinations(DeclarativeBase, DeclarativeBaseNoMeta) + def test_declarative_base_used_directly(self, base): + with assertions.expect_raises_message( + exc.InvalidRequestError, + f"Cannot use {base.__name__!r} directly as a declarative base", + ): + + class MyClass(base): + __tablename__ = "foobar" + id: int = mapped_column(primary_key=True) + + with assertions.expect_raises_message( + exc.InvalidRequestError, + f"Cannot use {base.__name__!r} directly as a declarative base", + ): + + class MyClass2(base): + __table__ = sa.Table( + "foobar", + sa.MetaData(), + sa.Column("id", Integer, primary_key=True), + ) + + @testing.combinations(DeclarativeBase, DeclarativeBaseNoMeta) + def test_declarative_base_fn_ok(self, base): + # __tablename__ or __table__ as declared_attr are ok in the base + class MyBase1(base): + @declared_attr + def __tablename__(cls): + return cls.__name__ + + class MyBase2(base): + @declared_attr + def __table__(cls): + return sa.Table( + "foobar", + sa.MetaData(), + sa.Column("id", Integer, primary_key=True), + ) + @testing.combinations( ("declarative_base_nometa_superclass",), -- 2.47.2