From: Mike Bayer Date: Mon, 27 Mar 2023 13:48:58 +0000 (-0400) Subject: support DeclarativeBase for versioned history example X-Git-Tag: rel_2_0_8~15^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=24dd3d8c90876a05377d04910819dcd5d25aed4e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git support DeclarativeBase for versioned history example Fixed issue in "versioned history" example where using a declarative base that is derived from :class:`_orm.DeclarativeBase` would fail to be mapped. Additionally, repaired the given test suite so that the documented instructions for running the example using Python unittest now work again. Change-Id: I164a5b8dbdd01e3d815eb356f7b7cadf226ca296 References: #9546 --- diff --git a/doc/build/changelog/unreleased_20/vers_fixes.rst b/doc/build/changelog/unreleased_20/vers_fixes.rst new file mode 100644 index 0000000000..d4f6411510 --- /dev/null +++ b/doc/build/changelog/unreleased_20/vers_fixes.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, examples + + Fixed issue in "versioned history" example where using a declarative base + that is derived from :class:`_orm.DeclarativeBase` would fail to be mapped. + Additionally, repaired the given test suite so that the documented + instructions for running the example using Python unittest now work again. diff --git a/examples/versioned_history/__init__.py b/examples/versioned_history/__init__.py index 14dbea5c05..0593881e2d 100644 --- a/examples/versioned_history/__init__.py +++ b/examples/versioned_history/__init__.py @@ -16,7 +16,8 @@ A fragment of example usage, using declarative:: from history_meta import Versioned, versioned_session - Base = declarative_base() + class Base(DeclarativeBase): + pass class SomeClass(Versioned, Base): __tablename__ = 'sometable' diff --git a/examples/versioned_history/history_meta.py b/examples/versioned_history/history_meta.py index 1176a5dffb..cc3ef2b0a5 100644 --- a/examples/versioned_history/history_meta.py +++ b/examples/versioned_history/history_meta.py @@ -6,6 +6,7 @@ from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import util @@ -174,22 +175,22 @@ def _history_mapper(local_mapper): else: bases = local_mapper.base_mapper.class_.__bases__ - versioned_cls = type.__new__( - type, + versioned_cls = type( "%sHistory" % cls.__name__, bases, - {"_history_mapper_configured": True}, + { + "_history_mapper_configured": True, + "__table__": history_table, + "__mapper_args__": dict( + inherits=super_history_mapper, + polymorphic_identity=local_mapper.polymorphic_identity, + polymorphic_on=polymorphic_on, + properties=properties, + ), + }, ) - m = local_mapper.registry.map_imperatively( - versioned_cls, - history_table, - inherits=super_history_mapper, - polymorphic_identity=local_mapper.polymorphic_identity, - polymorphic_on=polymorphic_on, - properties=properties, - ) - cls.__history_mapper__ = m + cls.__history_mapper__ = versioned_cls.__mapper__ class Versioned: @@ -201,9 +202,17 @@ class Versioned: are used for new rows even for rows that have been deleted.""" def __init_subclass__(cls) -> None: - @event.listens_for(cls, "after_mapper_constructed") - def _mapper_constructed(mapper, class_): - _history_mapper(mapper) + insp = inspect(cls, raiseerr=False) + + if insp is not None: + _history_mapper(insp) + else: + + @event.listens_for(cls, "after_mapper_constructed") + def _mapper_constructed(mapper, class_): + _history_mapper(mapper) + + super().__init_subclass__() def versioned_objects(iter_): diff --git a/examples/versioned_history/test_versioning.py b/examples/versioned_history/test_versioning.py index 8f96355922..9caadc0433 100644 --- a/examples/versioned_history/test_versioning.py +++ b/examples/versioned_history/test_versioning.py @@ -8,12 +8,15 @@ from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey +from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import join from sqlalchemy import select from sqlalchemy import String from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import column_property +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import deferred from sqlalchemy.orm import exc as orm_exc from sqlalchemy.orm import relationship @@ -21,6 +24,8 @@ from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_ignore_whitespace +from sqlalchemy.testing import is_ from sqlalchemy.testing import ne_ from sqlalchemy.testing.entities import ComparableEntity from .history_meta import Versioned @@ -37,7 +42,7 @@ class TestVersioning(AssertsCompiledSQL): self.engine = engine = create_engine("sqlite://") self.session = Session(engine) - self.Base = declarative_base() + self.make_base() versioned_session(self.session) def tearDown(self): @@ -45,6 +50,9 @@ class TestVersioning(AssertsCompiledSQL): clear_mappers() self.Base.metadata.drop_all(self.engine) + def make_base(self): + self.Base = declarative_base() + def create_tables(self): self.Base.metadata.create_all(self.engine) @@ -120,6 +128,37 @@ class TestVersioning(AssertsCompiledSQL): ], ) + def test_discussion_9546(self): + class ThingExternal(Versioned, self.Base): + __tablename__ = "things_external" + id = Column(Integer, primary_key=True) + external_attribute = Column(String) + + class ThingLocal(Versioned, self.Base): + __tablename__ = "things_local" + id = Column( + Integer, ForeignKey(ThingExternal.id), primary_key=True + ) + internal_attribute = Column(String) + + is_(ThingExternal.__table__, inspect(ThingExternal).local_table) + + class Thing(self.Base): + __table__ = join(ThingExternal, ThingLocal) + id = column_property(ThingExternal.id, ThingLocal.id) + version = column_property( + ThingExternal.version, ThingLocal.version + ) + + eq_ignore_whitespace( + str(select(Thing)), + "SELECT things_external.id, things_local.id AS id_1, " + "things_external.external_attribute, things_external.version, " + "things_local.version AS version_1, " + "things_local.internal_attribute FROM things_external " + "JOIN things_local ON things_external.id = things_local.id", + ) + def test_w_mapper_versioning(self): class SomeClass(Versioned, self.Base, ComparableEntity): __tablename__ = "sometable" @@ -750,7 +789,19 @@ class TestVersioning(AssertsCompiledSQL): sess.commit() -class TestVersioningUnittest(unittest.TestCase, TestVersioning): +class TestVersioningNewBase(TestVersioning): + def make_base(self): + class Base(DeclarativeBase): + pass + + self.Base = Base + + +class TestVersioningUnittest(TestVersioning, unittest.TestCase): + pass + + +class TestVersioningNewBaseUnittest(TestVersioningNewBase, unittest.TestCase): pass diff --git a/test/base/test_examples.py b/test/base/test_examples.py index 50f0c01f2c..4baddfb105 100644 --- a/test/base/test_examples.py +++ b/test/base/test_examples.py @@ -15,9 +15,17 @@ test_versioning = __import__( ).versioned_history.test_versioning -class VersionedRowsTest( +class VersionedRowsTestLegacyBase( test_versioning.TestVersioning, fixtures.RemoveORMEventsGlobally, fixtures.TestBase, ): pass + + +class VersionedRowsTestNewBase( + test_versioning.TestVersioningNewBase, + fixtures.RemoveORMEventsGlobally, + fixtures.TestBase, +): + pass