def before_configured():
self.cls.__declare_first__()
+ def _cls_attr_override_checker(self, cls):
+ """Produce a function that checks if a class has overridden an
+ attribute, taking SQLAlchemy-enabled dataclass fields into account.
+
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def attribute_is_overridden(key, obj):
+ return getattr(cls, key) is not obj
+
+ else:
+
+ all_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+ local_datacls_fields = {
+ f.name: f.metadata[sa_dataclass_metadata_key]
+ for f in util.local_dataclass_fields(cls)
+ if sa_dataclass_metadata_key in f.metadata
+ }
+
+ absent = object()
+
+ def attribute_is_overridden(key, obj):
+ # this function likely has some failure modes still if
+ # someone is doing a deep mixing of the same attribute
+ # name as plain Python attribute vs. dataclass field.
+
+ ret = local_datacls_fields.get(key, absent)
+
+ if ret is obj:
+ return False
+ elif ret is not absent:
+ return True
+
+ ret = getattr(cls, key, obj)
+
+ if ret is obj:
+ return False
+ elif ret is not absent:
+ return True
+
+ ret = all_datacls_fields.get(key, absent)
+
+ if ret is obj:
+ return False
+ elif ret is not absent:
+ return True
+
+ # can't find another attribute
+ return False
+
+ return attribute_is_overridden
+
+ def _cls_attr_resolver(self, cls):
+ """produce a function to iterate the "attributes" of a class,
+ adjusting for SQLAlchemy fields embedded in dataclass fields.
+
+ """
+ sa_dataclass_metadata_key = _get_immediate_cls_attr(
+ cls, "__sa_dataclass_metadata_key__", None
+ )
+
+ if sa_dataclass_metadata_key is None:
+
+ def local_attributes_for_class():
+ for name, obj in vars(cls).items():
+ yield name, obj
+
+ else:
+
+ def local_attributes_for_class():
+ for name, obj in vars(cls).items():
+ yield name, obj
+ for field in util.local_dataclass_fields(cls):
+ if sa_dataclass_metadata_key in field.metadata:
+ yield field.name, field.metadata[
+ sa_dataclass_metadata_key
+ ]
+
+ return local_attributes_for_class
+
def _scan_attributes(self):
cls = self.cls
dict_ = self.dict_
table_args = inherited_table_args = None
tablename = None
- for base in cls.__mro__:
+ attribute_is_overridden = self._cls_attr_override_checker(self.cls)
- sa_dataclass_metadata_key = None
+ for base in cls.__mro__:
class_mapped = (
base is not cls
)
)
- if sa_dataclass_metadata_key is None:
- sa_dataclass_metadata_key = _get_immediate_cls_attr(
- base, "__sa_dataclass_metadata_key__", None
- )
-
- def attributes_for_class(cls):
- for name, obj in vars(cls).items():
- yield name, obj
- if sa_dataclass_metadata_key:
- for field in util.dataclass_fields(cls):
- if sa_dataclass_metadata_key in field.metadata:
- yield field.name, field.metadata[
- sa_dataclass_metadata_key
- ]
+ local_attributes_for_class = self._cls_attr_resolver(base)
if not class_mapped and base is not cls:
- self._produce_column_copies(attributes_for_class, base)
+ self._produce_column_copies(
+ local_attributes_for_class, attribute_is_overridden
+ )
- for name, obj in attributes_for_class(base):
+ for name, obj in local_attributes_for_class():
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
else:
self._warn_for_decl_attributes(base, name, obj)
elif name not in dict_ or dict_[name] is not obj:
+ # here, we are definitely looking at the target class
+ # and not a superclass. this is currently a
+ # dataclass-only path. if the name is only
+ # a dataclass field and isn't in local cls.__dict__,
+ # put the object there.
+
+ # assert that the dataclass-enabled resolver agrees
+ # with what we are seeing
+ assert not attribute_is_overridden(name, obj)
dict_[name] = obj
if inherited_table_args and not tablename:
% (key, cls)
)
- def _produce_column_copies(self, attributes_for_class, base):
+ def _produce_column_copies(
+ self, attributes_for_class, attribute_is_overridden
+ ):
cls = self.cls
dict_ = self.dict_
column_copies = self.column_copies
# copy mixin columns to the mapped class
- for name, obj in attributes_for_class(base):
+
+ for name, obj in attributes_for_class():
if isinstance(obj, Column):
- if getattr(cls, name) is not obj:
+ if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
# superclass), skip
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
+from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import mapper
from sqlalchemy.orm import registry as declarative_registry
+from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
assert Widget("Foo") != Widget("Bar")
assert Widget("Foo") != SpecialWidget("Foo")
- def test_asdict_and_astuple(self):
+ def test_asdict_and_astuple_widget(self):
Widget = self.classes.Widget
- SpecialWidget = self.classes.SpecialWidget
-
widget = Widget("Foo")
eq_(dataclasses.asdict(widget), {"name": "Foo"})
eq_(dataclasses.astuple(widget), ("Foo",))
+ def test_asdict_and_astuple_special_widget(self):
+ SpecialWidget = self.classes.SpecialWidget
widget = SpecialWidget("Bar", magic=True)
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
eq_(dataclasses.astuple(widget), ("Bar", True))
Account = self.classes.Account
account = self.data_fixture()
- with Session(testing.db) as session:
+ with fixture_session() as session:
session.add(account)
session.commit()
- with Session(testing.db) as session:
+ with fixture_session() as session:
a = session.query(Account).get(42)
self.check_data_fixture(a)
def define_tables(cls, metadata):
pass
- def test_asdict_and_astuple(self):
+ def test_asdict_and_astuple_widget(self):
Widget = self.classes.Widget
- SpecialWidget = self.classes.SpecialWidget
widget = Widget("Foo")
eq_(dataclasses.asdict(widget), {"name": "Foo"})
eq_(dataclasses.astuple(widget), ("Foo",))
+ def test_asdict_and_astuple_special_widget(self):
+ SpecialWidget = self.classes.SpecialWidget
widget = SpecialWidget("Bar", magic=True)
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
eq_(dataclasses.astuple(widget), ("Bar", True))
+
+
+class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
+ __requires__ = ("dataclasses",)
+
+ @classmethod
+ def setup_classes(cls):
+ declarative = cls.DeclarativeBasic.registry.mapped
+
+ @dataclasses.dataclass
+ class SurrogateWidgetPK:
+
+ __sa_dataclass_metadata_key__ = "sa"
+
+ widget_id: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+
+ @declarative
+ @dataclasses.dataclass
+ class Widget(SurrogateWidgetPK):
+ __tablename__ = "widgets"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id = Column(
+ Integer,
+ ForeignKey("accounts.account_id"),
+ nullable=False,
+ )
+ type = Column(String(30), nullable=False)
+
+ name: Optional[str] = dataclasses.field(
+ default=None,
+ metadata={"sa": Column(String(30), nullable=False)},
+ )
+ __mapper_args__ = dict(
+ polymorphic_on="type",
+ polymorphic_identity="normal",
+ )
+
+ @declarative
+ @dataclasses.dataclass
+ class SpecialWidget(Widget):
+ __sa_dataclass_metadata_key__ = "sa"
+
+ magic: bool = dataclasses.field(
+ default=False, metadata={"sa": Column(Boolean)}
+ )
+
+ __mapper_args__ = dict(
+ polymorphic_identity="special",
+ )
+
+ @dataclasses.dataclass
+ class SurrogateAccountPK:
+
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id = Column(
+ "we_dont_want_to_use_this", Integer, primary_key=True
+ )
+
+ @declarative
+ @dataclasses.dataclass
+ class Account(SurrogateAccountPK):
+ __tablename__ = "accounts"
+ __sa_dataclass_metadata_key__ = "sa"
+
+ account_id: int = dataclasses.field(
+ metadata={"sa": Column(Integer, primary_key=True)},
+ )
+ widgets: List[Widget] = dataclasses.field(
+ default_factory=list, metadata={"sa": relationship("Widget")}
+ )
+ widget_count: int = dataclasses.field(
+ init=False,
+ metadata={
+ "sa": Column("widget_count", Integer, nullable=False)
+ },
+ )
+
+ def __post_init__(self):
+ self.widget_count = len(self.widgets)
+
+ def add_widget(self, widget: Widget):
+ self.widgets.append(widget)
+ self.widget_count += 1
+
+ cls.classes.Account = Account
+ cls.classes.Widget = Widget
+ cls.classes.SpecialWidget = SpecialWidget
+
+ def check_widget_dataclass(self, obj):
+ assert dataclasses.is_dataclass(obj)
+ (
+ id_,
+ name,
+ ) = dataclasses.fields(obj)
+ eq_(name.name, "name")
+ eq_(id_.name, "widget_id")
+
+ def check_special_widget_dataclass(self, obj):
+ assert dataclasses.is_dataclass(obj)
+ id_, name, magic = dataclasses.fields(obj)
+ eq_(id_.name, "widget_id")
+ eq_(name.name, "name")
+ eq_(magic.name, "magic")
+
+ def test_asdict_and_astuple_widget(self):
+ Widget = self.classes.Widget
+
+ widget = Widget("Foo")
+ eq_(dataclasses.asdict(widget), {"name": "Foo", "widget_id": None})
+ eq_(
+ dataclasses.astuple(widget),
+ (
+ None,
+ "Foo",
+ ),
+ )
+
+ def test_asdict_and_astuple_special_widget(self):
+ SpecialWidget = self.classes.SpecialWidget
+ widget = SpecialWidget("Bar", magic=True)
+ eq_(
+ dataclasses.asdict(widget),
+ {"name": "Bar", "magic": True, "widget_id": None},
+ )
+ eq_(dataclasses.astuple(widget), (None, "Bar", True))
+
+
+class PropagationBlockTest(fixtures.TestBase):
+ __requires__ = ("dataclasses",)
+
+ run_setup_classes = "each"
+ run_setup_mappers = "each"
+
+ def test_propagate_w_plain_mixin_col(self, run_test):
+ @dataclasses.dataclass
+ class CommonMixin:
+ __sa_dataclass_metadata_key__ = "sa"
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ __table_args__ = {"mysql_engine": "InnoDB"}
+ timestamp = Column(Integer)
+
+ run_test(CommonMixin)
+
+ def test_propagate_w_field_mixin_col(self, run_test):
+ @dataclasses.dataclass
+ class CommonMixin:
+ __sa_dataclass_metadata_key__ = "sa"
+
+ @declared_attr
+ def __tablename__(cls):
+ return cls.__name__.lower()
+
+ __table_args__ = {"mysql_engine": "InnoDB"}
+
+ timestamp: int = dataclasses.field(
+ init=False,
+ metadata={"sa": Column(Integer, nullable=False)},
+ )
+
+ run_test(CommonMixin)
+
+ @testing.fixture()
+ def run_test(self):
+ def go(CommonMixin):
+ declarative = registry().mapped
+
+ @declarative
+ @dataclasses.dataclass
+ class BaseType(CommonMixin):
+
+ discriminator = Column("type", String(50))
+ __mapper_args__ = dict(polymorphic_on=discriminator)
+ id = Column(Integer, primary_key=True)
+ value = Column(Integer())
+
+ @declarative
+ @dataclasses.dataclass
+ class Single(BaseType):
+
+ __tablename__ = None
+ __mapper_args__ = dict(polymorphic_identity="type1")
+
+ @declarative
+ @dataclasses.dataclass
+ class Joined(BaseType):
+
+ __mapper_args__ = dict(polymorphic_identity="type2")
+ id = Column(
+ Integer, ForeignKey("basetype.id"), primary_key=True
+ )
+
+ eq_(BaseType.__table__.name, "basetype")
+ eq_(
+ list(BaseType.__table__.c.keys()),
+ ["timestamp", "type", "id", "value"],
+ )
+ eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"})
+ assert Single.__table__ is BaseType.__table__
+ eq_(Joined.__table__.name, "joined")
+ eq_(list(Joined.__table__.c.keys()), ["id"])
+ eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"})
+
+ yield go
+
+ clear_mappers()