]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
read from cls.__dict__ so init_subclass works
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 4 Apr 2022 23:01:54 +0000 (19:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2022 19:10:32 +0000 (15:10 -0400)
Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__``
into the declarative scanning process to look for attributes, rather than
the separate dictionary passed to the type's ``__init__()`` method. This
allows user-defined base classes that add attributes within an
``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can
only affect the ``cls.__dict__`` itself and not the other dictionary. This
is technically a regression from 1.3 where ``__dict__`` was being used.

Additionally makes the reference between ClassManager and
the declarative configuration object a weak reference, so that it
can be discarded after mappers are set up.

Fixes: #7900
Change-Id: I3c2fd4e227cc1891aa4bb3d7d5b43d5686f9f27c

doc/build/changelog/unreleased_14/7900.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/instrumentation.py
test/orm/declarative/test_mixin.py

diff --git a/doc/build/changelog/unreleased_14/7900.rst b/doc/build/changelog/unreleased_14/7900.rst
new file mode 100644 (file)
index 0000000..9d6d507
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, orm, declarative
+    :tickets: 7900
+
+    Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__``
+    into the declarative scanning process to look for attributes, rather than
+    the separate dictionary passed to the type's ``__init__()`` method. This
+    allows user-defined base classes that add attributes within an
+    ``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can
+    only affect the ``cls.__dict__`` itself and not the other dictionary. This
+    is technically a regression from 1.3 where ``__dict__`` was being used.
+
+
+
index 4a699f63bbe41a6528f8a56dec4bca9033217510..70507015bce5048cb2d6cf062a0e686e6f8b5d2e 100644 (file)
@@ -109,6 +109,10 @@ class DeclarativeMeta(
     def __init__(
         cls, classname: Any, bases: Any, dict_: Any, **kw: Any
     ) -> None:
+        # use cls.__dict__, which can be modified by an
+        # __init_subclass__() method (#7900)
+        dict_ = cls.__dict__
+
         # early-consume registry from the initial declarative base,
         # assign privately to not conflict with subclass attributes named
         # "registry"
@@ -293,7 +297,8 @@ class declared_attr(interfaces._MappedAttribute[_T]):
 
         # here, we are inside of the declarative scan.  use the registry
         # that is tracking the values of these attributes.
-        declarative_scan = manager.declarative_scan
+        declarative_scan = manager.declarative_scan()
+        assert declarative_scan is not None
         reg = declarative_scan.declared_attr_reg
 
         if self in reg:
index 3fb8af80ca9488d6518d2fdd2ca103b98ca4f409..804d05ce1930b4f412cfc34830e1b14fa8fe151b 100644 (file)
@@ -161,7 +161,13 @@ def _check_declared_props_nocascade(obj, name, cls):
 
 
 class _MapperConfig:
-    __slots__ = ("cls", "classname", "properties", "declared_attr_reg")
+    __slots__ = (
+        "cls",
+        "classname",
+        "properties",
+        "declared_attr_reg",
+        "__weakref__",
+    )
 
     @classmethod
     def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
@@ -311,13 +317,15 @@ class _ClassScanMapperConfig(_MapperConfig):
         mapper_kw,
     ):
 
+        # grab class dict before the instrumentation manager has been added.
+        # reduces cycles
+        self.clsdict_view = (
+            util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
+        )
         super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
         self.registry = registry
         self.persist_selectable = None
 
-        self.clsdict_view = (
-            util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
-        )
         self.collected_attributes = {}
         self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
         self.declared_columns = util.OrderedSet()
index a5dc305d22472120d490d8cf6db538b6b1a3f591..0d4b630dad80c895b728784e578383838db09514 100644 (file)
@@ -39,6 +39,7 @@ from typing import Optional
 from typing import Set
 from typing import TYPE_CHECKING
 from typing import TypeVar
+import weakref
 
 from . import base
 from . import collections
@@ -167,7 +168,7 @@ class ClassManager(
         if registry:
             registry._add_manager(self)
         if declarative_scan:
-            self.declarative_scan = declarative_scan
+            self.declarative_scan = weakref.ref(declarative_scan)
         if expired_attribute_loader:
             self.expired_attribute_loader = expired_attribute_loader
 
index 5be8237e26899715f7058b7d1947d31eb2498823..36840b2d7a788aa1987c096a497b07f73b75207e 100644 (file)
@@ -43,7 +43,11 @@ Base = None
 mapper_registry = None
 
 
-class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
+class DeclarativeTestBase(
+    testing.AssertsCompiledSQL,
+    fixtures.TestBase,
+    testing.AssertsExecutionResults,
+):
     def setup_test(self):
         global Base, mapper_registry
 
@@ -58,6 +62,19 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
 
 
 class DeclarativeMixinTest(DeclarativeTestBase):
+    def test_init_subclass_works(self, registry):
+        class Base:
+            def __init_subclass__(cls):
+                cls.id = Column(Integer, primary_key=True)
+
+        Base = registry.generate_base(cls=Base)
+
+        class Foo(Base):
+            __tablename__ = "foo"
+            name = Column(String)
+
+        self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo")
+
     def test_simple_wbase(self):
         class MyMixin: