]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
restore declared_attr consumption for __table__
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Jun 2024 02:45:16 +0000 (22:45 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Jun 2024 03:23:09 +0000 (23:23 -0400)
Fixed bug in ORM Declarative where the ``__table__`` directive could not be
declared as a class function with :func:`_orm.declared_attr` on a
superclass, including an ``__abstract__`` class as well as coming from the
declarative base itself.  This was a regression since 1.4 where this was
working, and there were apparently no tests for this particular use case.

Fixes: #11509
Change-Id: I82ef0f93d00cb7a43b0b1b16ea28f1a9a79eba3b

doc/build/changelog/unreleased_20/11509.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_mixin.py

diff --git a/doc/build/changelog/unreleased_20/11509.rst b/doc/build/changelog/unreleased_20/11509.rst
new file mode 100644 (file)
index 0000000..1761c2b
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 11509
+
+    Fixed bug in ORM Declarative where the ``__table__`` directive could not be
+    declared as a class function with :func:`_orm.declared_attr` on a
+    superclass, including an ``__abstract__`` class as well as coming from the
+    declarative base itself.  This was a regression since 1.4 where this was
+    working, and there were apparently no tests for this particular use case.
index 0513eac66a01d2f3c37d6261a672bdbd1b31e1a7..90396403c2b10d2eb535dc59e8eae86bcd6c6bd6 100644 (file)
@@ -453,6 +453,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         "tablename",
         "mapper_args",
         "mapper_args_fn",
+        "table_fn",
         "inherits",
         "single",
         "allow_dataclass_fields",
@@ -759,7 +760,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         _include_dunders = self._include_dunders
         mapper_args_fn = None
         table_args = inherited_table_args = None
-
+        table_fn = None
         tablename = None
         fixed_table = "__table__" in clsdict_view
 
@@ -840,6 +841,22 @@ class _ClassScanMapperConfig(_MapperConfig):
                         )
                         if not tablename and (not class_mapped or check_decl):
                             tablename = cls_as_Decl.__tablename__
+                    elif name == "__table__":
+                        check_decl = _check_declared_props_nocascade(
+                            obj, name, cls
+                        )
+                        # if a @declared_attr using "__table__" is detected,
+                        # wrap up a callable to look for "__table__" from
+                        # the final concrete class when we set up a table.
+                        # this was fixed by
+                        # #11509, regression in 2.0 from version 1.4.
+                        if check_decl and not table_fn:
+                            # don't even invoke __table__ until we're ready
+                            def _table_fn() -> FromClause:
+                                return cls_as_Decl.__table__
+
+                            table_fn = _table_fn
+
                     elif name == "__table_args__":
                         check_decl = _check_declared_props_nocascade(
                             obj, name, cls
@@ -856,9 +873,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                             if base is not cls:
                                 inherited_table_args = True
                     else:
-                        # skip all other dunder names, which at the moment
-                        # should only be __table__
-                        continue
+                        # any other dunder names; should not be here
+                        # as we have tested for all four names in
+                        # _include_dunders
+                        assert False
                 elif class_mapped:
                     if _is_declarative_props(obj) and not obj._quiet:
                         util.warn(
@@ -1031,6 +1049,7 @@ class _ClassScanMapperConfig(_MapperConfig):
         self.table_args = table_args
         self.tablename = tablename
         self.mapper_args_fn = mapper_args_fn
+        self.table_fn = table_fn
 
     def _setup_dataclasses_transforms(self) -> None:
         dataclass_setup_arguments = self.dataclass_setup_arguments
@@ -1687,7 +1706,11 @@ class _ClassScanMapperConfig(_MapperConfig):
 
         manager = attributes.manager_of_class(cls)
 
-        if "__table__" not in clsdict_view and table is None:
+        if (
+            self.table_fn is None
+            and "__table__" not in clsdict_view
+            and table is None
+        ):
             if hasattr(cls, "__table_cls__"):
                 table_cls = cast(
                     Type[Table],
@@ -1733,7 +1756,12 @@ class _ClassScanMapperConfig(_MapperConfig):
                 )
         else:
             if table is None:
-                table = cls_as_Decl.__table__
+                if self.table_fn:
+                    table = self.set_cls_attribute(
+                        "__table__", self.table_fn()
+                    )
+                else:
+                    table = cls_as_Decl.__table__
             if declared_columns:
                 for c in declared_columns:
                     if not table.c.contains_column(c):
index 2520eb846d73fa0151383fab6494954fe0dd02ad..d670e96dcbfdd7da7e44ce0cda5aa05396592e3a 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
+from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -98,6 +99,159 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo")
 
+    @testing.variation("base_type", ["generate_base", "subclass"])
+    @testing.variation("attrname", ["table", "tablename"])
+    @testing.variation("position", ["base", "abstract"])
+    @testing.variation("assert_no_extra_cols", [True, False])
+    def test_declared_attr_on_base(
+        self, registry, base_type, attrname, position, assert_no_extra_cols
+    ):
+        """test #11509"""
+
+        if position.abstract:
+            if base_type.generate_base:
+                SuperBase = registry.generate_base()
+
+                class Base(SuperBase):
+                    __abstract__ = True
+                    if attrname.table:
+
+                        @declared_attr.directive
+                        def __table__(cls):
+                            return Table(
+                                cls.__name__,
+                                cls.registry.metadata,
+                                Column("id", Integer, primary_key=True),
+                            )
+
+                    elif attrname.tablename:
+
+                        @declared_attr.directive
+                        def __tablename__(cls):
+                            return cls.__name__
+
+                    else:
+                        attrname.fail()
+
+            elif base_type.subclass:
+
+                class SuperBase(DeclarativeBase):
+                    pass
+
+                class Base(SuperBase):
+                    __abstract__ = True
+                    if attrname.table:
+
+                        @declared_attr.directive
+                        def __table__(cls):
+                            return Table(
+                                cls.__name__,
+                                cls.registry.metadata,
+                                Column("id", Integer, primary_key=True),
+                            )
+
+                    elif attrname.tablename:
+
+                        @declared_attr.directive
+                        def __tablename__(cls):
+                            return cls.__name__
+
+                    else:
+                        attrname.fail()
+
+            else:
+                base_type.fail()
+        else:
+            if base_type.generate_base:
+
+                class Base:
+                    if attrname.table:
+
+                        @declared_attr.directive
+                        def __table__(cls):
+                            return Table(
+                                cls.__name__,
+                                cls.registry.metadata,
+                                Column("id", Integer, primary_key=True),
+                            )
+
+                    elif attrname.tablename:
+
+                        @declared_attr.directive
+                        def __tablename__(cls):
+                            return cls.__name__
+
+                    else:
+                        attrname.fail()
+
+                Base = registry.generate_base(cls=Base)
+            elif base_type.subclass:
+
+                class Base(DeclarativeBase):
+                    if attrname.table:
+
+                        @declared_attr.directive
+                        def __table__(cls):
+                            return Table(
+                                cls.__name__,
+                                cls.registry.metadata,
+                                Column("id", Integer, primary_key=True),
+                            )
+
+                    elif attrname.tablename:
+
+                        @declared_attr.directive
+                        def __tablename__(cls):
+                            return cls.__name__
+
+                    else:
+                        attrname.fail()
+
+            else:
+                base_type.fail()
+
+        if attrname.table and assert_no_extra_cols:
+            with expect_raises_message(
+                sa.exc.ArgumentError,
+                "Can't add additional column 'data' when specifying __table__",
+            ):
+
+                class MyNopeClass(Base):
+                    data = Column(String)
+
+            return
+
+        class MyClass(Base):
+            if attrname.tablename:
+                id = Column(Integer, primary_key=True)  # noqa: A001
+
+        class MyOtherClass(Base):
+            if attrname.tablename:
+                id = Column(Integer, primary_key=True)  # noqa: A001
+
+        t = Table(
+            "my_override",
+            Base.metadata,
+            Column("id", Integer, primary_key=True),
+        )
+
+        class MyOverrideClass(Base):
+            __table__ = t
+
+        Base.registry.configure()
+
+        # __table__ was assigned
+        assert isinstance(MyClass.__dict__["__table__"], schema.Table)
+        assert isinstance(MyOtherClass.__dict__["__table__"], schema.Table)
+
+        eq_(MyClass.__table__.name, "MyClass")
+        eq_(MyClass.__table__.c.keys(), ["id"])
+
+        eq_(MyOtherClass.__table__.name, "MyOtherClass")
+        eq_(MyOtherClass.__table__.c.keys(), ["id"])
+
+        is_(MyOverrideClass.__table__, t)
+
     def test_simple_wbase(self):
         class MyMixin:
             id = Column(