]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support DeclarativeBase for versioned history example
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Mar 2023 13:48:58 +0000 (09:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Mar 2023 13:55:57 +0000 (09:55 -0400)
Fixed issue in "versioned history" example where using a declarative base
that is derived from :class:`_orm.DeclarativeBase` would fail to be mapped.
Additionally, repaired the given test suite so that the documented
instructions for running the example using Python unittest now work again.

Change-Id: I164a5b8dbdd01e3d815eb356f7b7cadf226ca296
References: #9546

doc/build/changelog/unreleased_20/vers_fixes.rst [new file with mode: 0644]
examples/versioned_history/__init__.py
examples/versioned_history/history_meta.py
examples/versioned_history/test_versioning.py
test/base/test_examples.py

diff --git a/doc/build/changelog/unreleased_20/vers_fixes.rst b/doc/build/changelog/unreleased_20/vers_fixes.rst
new file mode 100644 (file)
index 0000000..d4f6411
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, examples
+
+    Fixed issue in "versioned history" example where using a declarative base
+    that is derived from :class:`_orm.DeclarativeBase` would fail to be mapped.
+    Additionally, repaired the given test suite so that the documented
+    instructions for running the example using Python unittest now work again.
index 14dbea5c0539c9853002ad1863eae179712381ff..0593881e2de3fca7d1c21a3cfc1e50faf1b77c58 100644 (file)
@@ -16,7 +16,8 @@ A fragment of example usage, using declarative::
 
     from history_meta import Versioned, versioned_session
 
-    Base = declarative_base()
+    class Base(DeclarativeBase):
+        pass
 
     class SomeClass(Versioned, Base):
         __tablename__ = 'sometable'
index 1176a5dffbf54959088a0beccd999648d9f5a2c9..cc3ef2b0a5bea96ae89eb655330fb86e5207e307 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import event
 from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import util
@@ -174,22 +175,22 @@ def _history_mapper(local_mapper):
     else:
         bases = local_mapper.base_mapper.class_.__bases__
 
-    versioned_cls = type.__new__(
-        type,
+    versioned_cls = type(
         "%sHistory" % cls.__name__,
         bases,
-        {"_history_mapper_configured": True},
+        {
+            "_history_mapper_configured": True,
+            "__table__": history_table,
+            "__mapper_args__": dict(
+                inherits=super_history_mapper,
+                polymorphic_identity=local_mapper.polymorphic_identity,
+                polymorphic_on=polymorphic_on,
+                properties=properties,
+            ),
+        },
     )
 
-    m = local_mapper.registry.map_imperatively(
-        versioned_cls,
-        history_table,
-        inherits=super_history_mapper,
-        polymorphic_identity=local_mapper.polymorphic_identity,
-        polymorphic_on=polymorphic_on,
-        properties=properties,
-    )
-    cls.__history_mapper__ = m
+    cls.__history_mapper__ = versioned_cls.__mapper__
 
 
 class Versioned:
@@ -201,9 +202,17 @@ class Versioned:
     are used for new rows even for rows that have been deleted."""
 
     def __init_subclass__(cls) -> None:
-        @event.listens_for(cls, "after_mapper_constructed")
-        def _mapper_constructed(mapper, class_):
-            _history_mapper(mapper)
+        insp = inspect(cls, raiseerr=False)
+
+        if insp is not None:
+            _history_mapper(insp)
+        else:
+
+            @event.listens_for(cls, "after_mapper_constructed")
+            def _mapper_constructed(mapper, class_):
+                _history_mapper(mapper)
+
+        super().__init_subclass__()
 
 
 def versioned_objects(iter_):
index 8f963559223d3d32481a9bf2a41df83ad0a20a2f..9caadc04332f2f122e3d8918fc84e7daf46ce5be 100644 (file)
@@ -8,12 +8,15 @@ from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import join
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import clear_mappers
 from sqlalchemy.orm import column_property
+from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm import relationship
@@ -21,6 +24,8 @@ from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.entities import ComparableEntity
 from .history_meta import Versioned
@@ -37,7 +42,7 @@ class TestVersioning(AssertsCompiledSQL):
 
         self.engine = engine = create_engine("sqlite://")
         self.session = Session(engine)
-        self.Base = declarative_base()
+        self.make_base()
         versioned_session(self.session)
 
     def tearDown(self):
@@ -45,6 +50,9 @@ class TestVersioning(AssertsCompiledSQL):
         clear_mappers()
         self.Base.metadata.drop_all(self.engine)
 
+    def make_base(self):
+        self.Base = declarative_base()
+
     def create_tables(self):
         self.Base.metadata.create_all(self.engine)
 
@@ -120,6 +128,37 @@ class TestVersioning(AssertsCompiledSQL):
             ],
         )
 
+    def test_discussion_9546(self):
+        class ThingExternal(Versioned, self.Base):
+            __tablename__ = "things_external"
+            id = Column(Integer, primary_key=True)
+            external_attribute = Column(String)
+
+        class ThingLocal(Versioned, self.Base):
+            __tablename__ = "things_local"
+            id = Column(
+                Integer, ForeignKey(ThingExternal.id), primary_key=True
+            )
+            internal_attribute = Column(String)
+
+        is_(ThingExternal.__table__, inspect(ThingExternal).local_table)
+
+        class Thing(self.Base):
+            __table__ = join(ThingExternal, ThingLocal)
+            id = column_property(ThingExternal.id, ThingLocal.id)
+            version = column_property(
+                ThingExternal.version, ThingLocal.version
+            )
+
+        eq_ignore_whitespace(
+            str(select(Thing)),
+            "SELECT things_external.id, things_local.id AS id_1, "
+            "things_external.external_attribute, things_external.version, "
+            "things_local.version AS version_1, "
+            "things_local.internal_attribute FROM things_external "
+            "JOIN things_local ON things_external.id = things_local.id",
+        )
+
     def test_w_mapper_versioning(self):
         class SomeClass(Versioned, self.Base, ComparableEntity):
             __tablename__ = "sometable"
@@ -750,7 +789,19 @@ class TestVersioning(AssertsCompiledSQL):
         sess.commit()
 
 
-class TestVersioningUnittest(unittest.TestCase, TestVersioning):
+class TestVersioningNewBase(TestVersioning):
+    def make_base(self):
+        class Base(DeclarativeBase):
+            pass
+
+        self.Base = Base
+
+
+class TestVersioningUnittest(TestVersioning, unittest.TestCase):
+    pass
+
+
+class TestVersioningNewBaseUnittest(TestVersioningNewBase, unittest.TestCase):
     pass
 
 
index 50f0c01f2cc2f8bba8495de247c035ed6d2291b2..4baddfb105ae9713aab1f2e83461f1c02d7a92ad 100644 (file)
@@ -15,9 +15,17 @@ test_versioning = __import__(
 ).versioned_history.test_versioning
 
 
-class VersionedRowsTest(
+class VersionedRowsTestLegacyBase(
     test_versioning.TestVersioning,
     fixtures.RemoveORMEventsGlobally,
     fixtures.TestBase,
 ):
     pass
+
+
+class VersionedRowsTestNewBase(
+    test_versioning.TestVersioningNewBase,
+    fixtures.RemoveORMEventsGlobally,
+    fixtures.TestBase,
+):
+    pass