From 7fdeec1f3224f48213c9c9af5f3e7e5d0904dafa Mon Sep 17 00:00:00 2001 From: Michael Oliver Date: Mon, 4 Dec 2023 14:25:00 +0000 Subject: [PATCH] pass **`kwargs` in `__init_subclass__` to super Fixes: #10732 --- lib/sqlalchemy/orm/decl_api.py | 10 +++--- test/orm/declarative/test_basic.py | 52 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index f2039afcd5..b1fc80e5f9 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -594,6 +594,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): dataclass_callable: Union[ _NoArg, Callable[..., Type[Any]] ] = _NoArg.NO_ARG, + **kw: Any, ) -> None: apply_dc_transforms: _DataclassArguments = { "init": init, @@ -622,7 +623,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): current_transforms ) = apply_dc_transforms - super().__init_subclass__() + super().__init_subclass__(**kw) if not _is_mapped_class(cls): new_anno = ( @@ -839,13 +840,13 @@ class DeclarativeBase( def __init__(self, **kw: Any): ... - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBase in cls.__bases__: _check_not_declarative(cls, DeclarativeBase) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) - super().__init_subclass__() + super().__init_subclass__(**kw) def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: @@ -964,12 +965,13 @@ class DeclarativeBaseNoMeta( def __init__(self, **kw: Any): ... - def __init_subclass__(cls) -> None: + def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBaseNoMeta in cls.__bases__: _check_not_declarative(cls, DeclarativeBaseNoMeta) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__(**kw) def add_mapped_attribute( diff --git a/test/orm/declarative/test_basic.py b/test/orm/declarative/test_basic.py index 7085b2af9f..5515ccab6d 100644 --- a/test/orm/declarative/test_basic.py +++ b/test/orm/declarative/test_basic.py @@ -35,6 +35,7 @@ from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import joinedload from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import MappedColumn from sqlalchemy.orm import Mapper from sqlalchemy.orm import registry @@ -930,6 +931,57 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase): # Check to see if __init_subclass__ works in supported versions eq_(UserType._set_random_keyword_used_here, True) + def test_kw_support_in_declarative_init_subclass(self): + # This will not fail if DeclarativeBase __init_subclass__ + # supports **kw + class Base(DeclarativeBase): + pass + + class Mixin: + def __init_subclass__(cls, random_keyword: bool, **kw) -> None: + super().__init_subclass__(**kw) + cls._set_random_keyword_used_here = random_keyword + + class User(Base, Mixin, random_keyword=True): + __tablename__ = "user" + id_ = Column(Integer, primary_key=True) + + eq_(User._set_random_keyword_used_here, True) + + def test_kw_support_in_declarative_no_meta_init_subclass(self): + # This will not fail if DeclarativeBaseNoMeta __init_subclass__ + # supports **kw + class Base(DeclarativeBaseNoMeta): + pass + + class Mixin: + def __init_subclass__(cls, random_keyword: bool, **kw) -> None: + super().__init_subclass__(**kw) + cls._set_random_keyword_used_here = random_keyword + + class User(Base, Mixin, random_keyword=True): + __tablename__ = "user" + id_ = Column(Integer, primary_key=True) + + eq_(User._set_random_keyword_used_here, True) + + def test_kw_support_in_mapped_as_dataclass_init_subclass(self): + # This will not fail if MappedAsDataclass __init_subclass__ + # supports **kw + class Base(MappedAsDataclass): + pass + + class Mixin: + def __init_subclass__(cls, random_keyword: bool, **kw) -> None: + super().__init_subclass__(**kw) + cls._set_random_keyword_used_here = random_keyword + + class User(Base, Mixin, random_keyword=True): + __tablename__ = "user" + id_ = Column(Integer, primary_key=True) + + eq_(User._set_random_keyword_used_here, True) + def test_declarative_base_bad_registry(self): with assertions.expect_raises_message( exc.InvalidRequestError, -- 2.47.3