]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
pass **`kwargs` in `__init_subclass__` to super 10733/head
authorMichael Oliver <michael@michaeloliver.dev>
Mon, 4 Dec 2023 14:25:00 +0000 (14:25 +0000)
committerMichael Oliver <michael@michaeloliver.dev>
Tue, 5 Dec 2023 09:49:21 +0000 (09:49 +0000)
Fixes: #10732
lib/sqlalchemy/orm/decl_api.py
test/orm/declarative/test_basic.py

index f2039afcd5405c60de7a1b574f71c4dde6e4388a..b1fc80e5f939543e4fcf3e00fa846ca87b51eb5d 100644 (file)
@@ -594,6 +594,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
         dataclass_callable: Union[
             _NoArg, Callable[..., Type[Any]]
         ] = _NoArg.NO_ARG,
+        **kw: Any,
     ) -> None:
         apply_dc_transforms: _DataclassArguments = {
             "init": init,
@@ -622,7 +623,7 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
                 current_transforms
             ) = apply_dc_transforms
 
-        super().__init_subclass__()
+        super().__init_subclass__(**kw)
 
         if not _is_mapped_class(cls):
             new_anno = (
@@ -839,13 +840,13 @@ class DeclarativeBase(
         def __init__(self, **kw: Any):
             ...
 
-    def __init_subclass__(cls) -> None:
+    def __init_subclass__(cls, **kw: Any) -> None:
         if DeclarativeBase in cls.__bases__:
             _check_not_declarative(cls, DeclarativeBase)
             _setup_declarative_base(cls)
         else:
             _as_declarative(cls._sa_registry, cls, cls.__dict__)
-        super().__init_subclass__()
+        super().__init_subclass__(**kw)
 
 
 def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None:
@@ -964,12 +965,13 @@ class DeclarativeBaseNoMeta(
         def __init__(self, **kw: Any):
             ...
 
-    def __init_subclass__(cls) -> None:
+    def __init_subclass__(cls, **kw: Any) -> None:
         if DeclarativeBaseNoMeta in cls.__bases__:
             _check_not_declarative(cls, DeclarativeBaseNoMeta)
             _setup_declarative_base(cls)
         else:
             _as_declarative(cls._sa_registry, cls, cls.__dict__)
+        super().__init_subclass__(**kw)
 
 
 def add_mapped_attribute(
index 7085b2af9f6b5a8fb3393520a89c7b361364047e..5515ccab6ddc83d98f1fc625291a48406914f1c1 100644 (file)
@@ -35,6 +35,7 @@ from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import MappedColumn
 from sqlalchemy.orm import Mapper
 from sqlalchemy.orm import registry
@@ -930,6 +931,57 @@ class DeclarativeBaseSetupsTest(fixtures.TestBase):
         # Check to see if __init_subclass__ works in supported versions
         eq_(UserType._set_random_keyword_used_here, True)
 
+    def test_kw_support_in_declarative_init_subclass(self):
+        # This will not fail if DeclarativeBase __init_subclass__
+        # supports **kw
+        class Base(DeclarativeBase):
+            pass
+
+        class Mixin:
+            def __init_subclass__(cls, random_keyword: bool, **kw) -> None:
+                super().__init_subclass__(**kw)
+                cls._set_random_keyword_used_here = random_keyword
+
+        class User(Base, Mixin, random_keyword=True):
+            __tablename__ = "user"
+            id_ = Column(Integer, primary_key=True)
+
+        eq_(User._set_random_keyword_used_here, True)
+
+    def test_kw_support_in_declarative_no_meta_init_subclass(self):
+        # This will not fail if DeclarativeBaseNoMeta __init_subclass__
+        # supports **kw
+        class Base(DeclarativeBaseNoMeta):
+            pass
+
+        class Mixin:
+            def __init_subclass__(cls, random_keyword: bool, **kw) -> None:
+                super().__init_subclass__(**kw)
+                cls._set_random_keyword_used_here = random_keyword
+
+        class User(Base, Mixin, random_keyword=True):
+            __tablename__ = "user"
+            id_ = Column(Integer, primary_key=True)
+
+        eq_(User._set_random_keyword_used_here, True)
+
+    def test_kw_support_in_mapped_as_dataclass_init_subclass(self):
+        # This will not fail if MappedAsDataclass __init_subclass__
+        # supports **kw
+        class Base(MappedAsDataclass):
+            pass
+
+        class Mixin:
+            def __init_subclass__(cls, random_keyword: bool, **kw) -> None:
+                super().__init_subclass__(**kw)
+                cls._set_random_keyword_used_here = random_keyword
+
+        class User(Base, Mixin, random_keyword=True):
+            __tablename__ = "user"
+            id_ = Column(Integer, primary_key=True)
+
+        eq_(User._set_random_keyword_used_here, True)
+
     def test_declarative_base_bad_registry(self):
         with assertions.expect_raises_message(
             exc.InvalidRequestError,