From e35bcc679eb2e4404e72ee8628a3e32fb1ebd71c Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Fri, 25 Oct 2024 10:53:58 -0700 Subject: [PATCH] refactor: update dict_of_sets_with_default to declarative Signed-off-by: Mike Fiedler --- .../association/dict_of_sets_with_default.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/examples/association/dict_of_sets_with_default.py b/examples/association/dict_of_sets_with_default.py index f515ab975b..94b50bb9bf 100644 --- a/examples/association/dict_of_sets_with_default.py +++ b/examples/association/dict_of_sets_with_default.py @@ -12,29 +12,28 @@ upon access of a non-existent key, in the same manner as Python's """ +from __future__ import annotations + import operator -from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey -from sqlalchemy import Integer -from sqlalchemy import String from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.associationproxy import AssociationProxy +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm.collections import KeyFuncDict -class Base: - id = Column(Integer, primary_key=True) - +class Base(DeclarativeBase): + id: Mapped[int] = mapped_column(primary_key=True) -Base = declarative_base(cls=Base) - -class GenDefaultCollection(KeyFuncDict): - def __missing__(self, key): +class GenDefaultCollection(KeyFuncDict[str, "B"]): + def __missing__(self, key: str) -> B: self[key] = b = B(key) return b @@ -56,15 +55,17 @@ class A(Base): class B(Base): __tablename__ = "b" - a_id = Column(Integer, ForeignKey("a.id"), nullable=False) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) elements = relationship("C", collection_class=set) - key = Column(String) + key: Mapped[str] - values = association_proxy("elements", "value") + values: AssociationProxy[list[str]] = association_proxy( + "elements", "value" + ) """Bridge the association from 'elements' over to the 'value' element of C.""" - def __init__(self, key, values=None): + def __init__(self, key: str, values: list[str] | None = None) -> None: self.key = key if values: self.values = values @@ -72,10 +73,10 @@ class B(Base): class C(Base): __tablename__ = "c" - b_id = Column(Integer, ForeignKey("b.id"), nullable=False) - value = Column(Integer) + b_id: Mapped[int] = mapped_column(ForeignKey("b.id")) + value: Mapped[int] - def __init__(self, value): + def __init__(self, value: int) -> None: self.value = value @@ -90,7 +91,7 @@ if __name__ == "__main__": session.add_all([A(collections={"1": {1, 2, 3}})]) session.commit() - a1 = session.query(A).first() + a1 = session.query(A).one() print(a1.collections["1"]) a1.collections["1"].add(4) session.commit() -- 2.47.3