From 289894f9af4bebee499969ee8701e06eb8527913 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 4 Apr 2022 19:01:54 -0400 Subject: [PATCH] read from cls.__dict__ so init_subclass works Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__`` into the declarative scanning process to look for attributes, rather than the separate dictionary passed to the type's ``__init__()`` method. This allows user-defined base classes that add attributes within an ``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can only affect the ``cls.__dict__`` itself and not the other dictionary. This is technically a regression from 1.3 where ``__dict__`` was being used. Additionally makes the reference between ClassManager and the declarative configuration object a weak reference, so that it can be discarded after mappers are set up. Fixes: #7900 Change-Id: I3c2fd4e227cc1891aa4bb3d7d5b43d5686f9f27c (cherry picked from commit 428ea01f00a9cc7f85e435018565eb6da7af1b77) --- doc/build/changelog/unreleased_14/7900.rst | 14 ++++++++++++++ lib/sqlalchemy/orm/decl_api.py | 7 ++++++- lib/sqlalchemy/orm/decl_base.py | 13 +++++++++++-- lib/sqlalchemy/orm/instrumentation.py | 4 +++- test/orm/declarative/test_mixin.py | 20 +++++++++++++++++++- 5 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/7900.rst diff --git a/doc/build/changelog/unreleased_14/7900.rst b/doc/build/changelog/unreleased_14/7900.rst new file mode 100644 index 0000000000..9d6d507703 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7900.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, orm, declarative + :tickets: 7900 + + Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__`` + into the declarative scanning process to look for attributes, rather than + the separate dictionary passed to the type's ``__init__()`` method. This + allows user-defined base classes that add attributes within an + ``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can + only affect the ``cls.__dict__`` itself and not the other dictionary. This + is technically a regression from 1.3 where ``__dict__`` was being used. + + + diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 4b6c710c72..16f91c69dd 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -54,6 +54,10 @@ def has_inherited_table(cls): class DeclarativeMeta(type): def __init__(cls, classname, bases, dict_, **kw): + # use cls.__dict__, which can be modified by an + # __init_subclass__() method (#7900) + dict_ = cls.__dict__ + # early-consume registry from the initial declarative base, # assign privately to not conflict with subclass attributes named # "registry" @@ -228,7 +232,8 @@ class declared_attr(interfaces._MappedAttribute, property): # here, we are inside of the declarative scan. use the registry # that is tracking the values of these attributes. - declarative_scan = manager.declarative_scan + declarative_scan = manager.declarative_scan() + assert declarative_scan is not None reg = declarative_scan.declared_attr_reg if desc in reg: diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 6f02e56977..ed4ccd1968 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -152,7 +152,13 @@ def _check_declared_props_nocascade(obj, name, cls): class _MapperConfig(object): - __slots__ = ("cls", "classname", "properties", "declared_attr_reg") + __slots__ = ( + "cls", + "classname", + "properties", + "declared_attr_reg", + "__weakref__", + ) @classmethod def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw): @@ -300,9 +306,12 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_kw, ): + # grab class dict before the instrumentation manager has been added. + # reduces cycles + self.dict_ = dict(dict_) if dict_ else {} + super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw) - self.dict_ = dict(dict_) if dict_ else {} self.persist_selectable = None self.declared_columns = set() self.column_copies = {} diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 97692b6421..a7023a21d9 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -30,6 +30,8 @@ alternate instrumentation forms. """ +import weakref + from . import base from . import collections from . import exc @@ -131,7 +133,7 @@ class ClassManager(HasMemoized, dict): if registry: registry._add_manager(self) if declarative_scan: - self.declarative_scan = declarative_scan + self.declarative_scan = weakref.ref(declarative_scan) if expired_attribute_loader: self.expired_attribute_loader = expired_attribute_loader diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 5a4673a23e..78ab4dbfc3 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -38,7 +38,11 @@ Base = None mapper_registry = None -class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): +class DeclarativeTestBase( + testing.AssertsCompiledSQL, + fixtures.TestBase, + testing.AssertsExecutionResults, +): def setup_test(self): global Base, mapper_registry @@ -53,6 +57,20 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults): class DeclarativeMixinTest(DeclarativeTestBase): + @testing.requires.python3 + def test_init_subclass_works(self, registry): + class Base: + def __init_subclass__(cls): + cls.id = Column(Integer, primary_key=True) + + Base = registry.generate_base(cls=Base) + + class Foo(Base): + __tablename__ = "foo" + name = Column(String) + + self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo") + def test_simple_wbase(self): class MyMixin(object): -- 2.47.2