From: Mike Bayer Date: Tue, 18 Jun 2024 02:45:16 +0000 (-0400) Subject: restore declared_attr consumption for __table__ X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=30ec43440168fa79a4d45db64c387562ef8f97e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git restore declared_attr consumption for __table__ Fixed bug in ORM Declarative where the ``__table__`` directive could not be declared as a class function with :func:`_orm.declared_attr` on a superclass, including an ``__abstract__`` class as well as coming from the declarative base itself. This was a regression since 1.4 where this was working, and there were apparently no tests for this particular use case. Fixes: #11509 Change-Id: I82ef0f93d00cb7a43b0b1b16ea28f1a9a79eba3b --- diff --git a/doc/build/changelog/unreleased_20/11509.rst b/doc/build/changelog/unreleased_20/11509.rst new file mode 100644 index 0000000000..1761c2bf7a --- /dev/null +++ b/doc/build/changelog/unreleased_20/11509.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm, regression + :tickets: 11509 + + Fixed bug in ORM Declarative where the ``__table__`` directive could not be + declared as a class function with :func:`_orm.declared_attr` on a + superclass, including an ``__abstract__`` class as well as coming from the + declarative base itself. This was a regression since 1.4 where this was + working, and there were apparently no tests for this particular use case. diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 0513eac66a..90396403c2 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -453,6 +453,7 @@ class _ClassScanMapperConfig(_MapperConfig): "tablename", "mapper_args", "mapper_args_fn", + "table_fn", "inherits", "single", "allow_dataclass_fields", @@ -759,7 +760,7 @@ class _ClassScanMapperConfig(_MapperConfig): _include_dunders = self._include_dunders mapper_args_fn = None table_args = inherited_table_args = None - + table_fn = None tablename = None fixed_table = "__table__" in clsdict_view @@ -840,6 +841,22 @@ class _ClassScanMapperConfig(_MapperConfig): ) if not tablename and (not class_mapped or check_decl): tablename = cls_as_Decl.__tablename__ + elif name == "__table__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + # if a @declared_attr using "__table__" is detected, + # wrap up a callable to look for "__table__" from + # the final concrete class when we set up a table. + # this was fixed by + # #11509, regression in 2.0 from version 1.4. + if check_decl and not table_fn: + # don't even invoke __table__ until we're ready + def _table_fn() -> FromClause: + return cls_as_Decl.__table__ + + table_fn = _table_fn + elif name == "__table_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -856,9 +873,10 @@ class _ClassScanMapperConfig(_MapperConfig): if base is not cls: inherited_table_args = True else: - # skip all other dunder names, which at the moment - # should only be __table__ - continue + # any other dunder names; should not be here + # as we have tested for all four names in + # _include_dunders + assert False elif class_mapped: if _is_declarative_props(obj) and not obj._quiet: util.warn( @@ -1031,6 +1049,7 @@ class _ClassScanMapperConfig(_MapperConfig): self.table_args = table_args self.tablename = tablename self.mapper_args_fn = mapper_args_fn + self.table_fn = table_fn def _setup_dataclasses_transforms(self) -> None: dataclass_setup_arguments = self.dataclass_setup_arguments @@ -1687,7 +1706,11 @@ class _ClassScanMapperConfig(_MapperConfig): manager = attributes.manager_of_class(cls) - if "__table__" not in clsdict_view and table is None: + if ( + self.table_fn is None + and "__table__" not in clsdict_view + and table is None + ): if hasattr(cls, "__table_cls__"): table_cls = cast( Type[Table], @@ -1733,7 +1756,12 @@ class _ClassScanMapperConfig(_MapperConfig): ) else: if table is None: - table = cls_as_Decl.__table__ + if self.table_fn: + table = self.set_cls_attribute( + "__table__", self.table_fn() + ) + else: + table = cls_as_Decl.__table__ if declared_columns: for c in declared_columns: if not table.c.contains_column(c): diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 2520eb846d..d670e96dcb 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -7,6 +7,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import schema from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing @@ -98,6 +99,159 @@ class DeclarativeMixinTest(DeclarativeTestBase): self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo") + @testing.variation("base_type", ["generate_base", "subclass"]) + @testing.variation("attrname", ["table", "tablename"]) + @testing.variation("position", ["base", "abstract"]) + @testing.variation("assert_no_extra_cols", [True, False]) + def test_declared_attr_on_base( + self, registry, base_type, attrname, position, assert_no_extra_cols + ): + """test #11509""" + + if position.abstract: + if base_type.generate_base: + SuperBase = registry.generate_base() + + class Base(SuperBase): + __abstract__ = True + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + elif base_type.subclass: + + class SuperBase(DeclarativeBase): + pass + + class Base(SuperBase): + __abstract__ = True + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + else: + base_type.fail() + else: + if base_type.generate_base: + + class Base: + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + Base = registry.generate_base(cls=Base) + elif base_type.subclass: + + class Base(DeclarativeBase): + if attrname.table: + + @declared_attr.directive + def __table__(cls): + return Table( + cls.__name__, + cls.registry.metadata, + Column("id", Integer, primary_key=True), + ) + + elif attrname.tablename: + + @declared_attr.directive + def __tablename__(cls): + return cls.__name__ + + else: + attrname.fail() + + else: + base_type.fail() + + if attrname.table and assert_no_extra_cols: + with expect_raises_message( + sa.exc.ArgumentError, + "Can't add additional column 'data' when specifying __table__", + ): + + class MyNopeClass(Base): + data = Column(String) + + return + + class MyClass(Base): + if attrname.tablename: + id = Column(Integer, primary_key=True) # noqa: A001 + + class MyOtherClass(Base): + if attrname.tablename: + id = Column(Integer, primary_key=True) # noqa: A001 + + t = Table( + "my_override", + Base.metadata, + Column("id", Integer, primary_key=True), + ) + + class MyOverrideClass(Base): + __table__ = t + + Base.registry.configure() + + # __table__ was assigned + assert isinstance(MyClass.__dict__["__table__"], schema.Table) + assert isinstance(MyOtherClass.__dict__["__table__"], schema.Table) + + eq_(MyClass.__table__.name, "MyClass") + eq_(MyClass.__table__.c.keys(), ["id"]) + + eq_(MyOtherClass.__table__.name, "MyOtherClass") + eq_(MyOtherClass.__table__.c.keys(), ["id"]) + + is_(MyOverrideClass.__table__, t) + def test_simple_wbase(self): class MyMixin: id = Column(