from sqlalchemy import DateTime
from sqlalchemy import event
from sqlalchemy import ForeignKeyConstraint
+from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import util
else:
bases = local_mapper.base_mapper.class_.__bases__
- versioned_cls = type.__new__(
- type,
+ versioned_cls = type(
"%sHistory" % cls.__name__,
bases,
- {"_history_mapper_configured": True},
+ {
+ "_history_mapper_configured": True,
+ "__table__": history_table,
+ "__mapper_args__": dict(
+ inherits=super_history_mapper,
+ polymorphic_identity=local_mapper.polymorphic_identity,
+ polymorphic_on=polymorphic_on,
+ properties=properties,
+ ),
+ },
)
- m = local_mapper.registry.map_imperatively(
- versioned_cls,
- history_table,
- inherits=super_history_mapper,
- polymorphic_identity=local_mapper.polymorphic_identity,
- polymorphic_on=polymorphic_on,
- properties=properties,
- )
- cls.__history_mapper__ = m
+ cls.__history_mapper__ = versioned_cls.__mapper__
class Versioned:
are used for new rows even for rows that have been deleted."""
def __init_subclass__(cls) -> None:
- @event.listens_for(cls, "after_mapper_constructed")
- def _mapper_constructed(mapper, class_):
- _history_mapper(mapper)
+ insp = inspect(cls, raiseerr=False)
+
+ if insp is not None:
+ _history_mapper(insp)
+ else:
+
+ @event.listens_for(cls, "after_mapper_constructed")
+ def _mapper_constructed(mapper, class_):
+ _history_mapper(mapper)
+
+ super().__init_subclass__()
def versioned_objects(iter_):
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKey
+from sqlalchemy import inspect
from sqlalchemy import Integer
+from sqlalchemy import join
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import column_property
+from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import deferred
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import eq_ignore_whitespace
+from sqlalchemy.testing import is_
from sqlalchemy.testing import ne_
from sqlalchemy.testing.entities import ComparableEntity
from .history_meta import Versioned
self.engine = engine = create_engine("sqlite://")
self.session = Session(engine)
- self.Base = declarative_base()
+ self.make_base()
versioned_session(self.session)
def tearDown(self):
clear_mappers()
self.Base.metadata.drop_all(self.engine)
+ def make_base(self):
+ self.Base = declarative_base()
+
def create_tables(self):
self.Base.metadata.create_all(self.engine)
],
)
+ def test_discussion_9546(self):
+ class ThingExternal(Versioned, self.Base):
+ __tablename__ = "things_external"
+ id = Column(Integer, primary_key=True)
+ external_attribute = Column(String)
+
+ class ThingLocal(Versioned, self.Base):
+ __tablename__ = "things_local"
+ id = Column(
+ Integer, ForeignKey(ThingExternal.id), primary_key=True
+ )
+ internal_attribute = Column(String)
+
+ is_(ThingExternal.__table__, inspect(ThingExternal).local_table)
+
+ class Thing(self.Base):
+ __table__ = join(ThingExternal, ThingLocal)
+ id = column_property(ThingExternal.id, ThingLocal.id)
+ version = column_property(
+ ThingExternal.version, ThingLocal.version
+ )
+
+ eq_ignore_whitespace(
+ str(select(Thing)),
+ "SELECT things_external.id, things_local.id AS id_1, "
+ "things_external.external_attribute, things_external.version, "
+ "things_local.version AS version_1, "
+ "things_local.internal_attribute FROM things_external "
+ "JOIN things_local ON things_external.id = things_local.id",
+ )
+
def test_w_mapper_versioning(self):
class SomeClass(Versioned, self.Base, ComparableEntity):
__tablename__ = "sometable"
sess.commit()
-class TestVersioningUnittest(unittest.TestCase, TestVersioning):
+class TestVersioningNewBase(TestVersioning):
+ def make_base(self):
+ class Base(DeclarativeBase):
+ pass
+
+ self.Base = Base
+
+
+class TestVersioningUnittest(TestVersioning, unittest.TestCase):
+ pass
+
+
+class TestVersioningNewBaseUnittest(TestVersioningNewBase, unittest.TestCase):
pass