]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactor: update dict_of_sets_with_default to declarative
authorMike Fiedler <miketheman@gmail.com>
Fri, 25 Oct 2024 17:53:58 +0000 (10:53 -0700)
committerMike Fiedler <miketheman@gmail.com>
Mon, 18 Nov 2024 18:45:38 +0000 (13:45 -0500)
Signed-off-by: Mike Fiedler <miketheman@gmail.com>
examples/association/dict_of_sets_with_default.py

index f515ab975b5b8839a71372d6b323ca6df7dc26f2..94b50bb9bfe5871a8276f29d909dc341f91703c3 100644 (file)
@@ -12,29 +12,28 @@ upon access of a non-existent key, in the same manner as Python's
 
 """
 
+from __future__ import annotations
+
 import operator
 
-from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
-from sqlalchemy import Integer
-from sqlalchemy import String
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.collections import KeyFuncDict
 
 
-class Base:
-    id = Column(Integer, primary_key=True)
-
+class Base(DeclarativeBase):
+    id: Mapped[int] = mapped_column(primary_key=True)
 
-Base = declarative_base(cls=Base)
 
-
-class GenDefaultCollection(KeyFuncDict):
-    def __missing__(self, key):
+class GenDefaultCollection(KeyFuncDict[str, "B"]):
+    def __missing__(self, key: str) -> B:
         self[key] = b = B(key)
         return b
 
@@ -56,15 +55,17 @@ class A(Base):
 
 class B(Base):
     __tablename__ = "b"
-    a_id = Column(Integer, ForeignKey("a.id"), nullable=False)
+    a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
     elements = relationship("C", collection_class=set)
-    key = Column(String)
+    key: Mapped[str]
 
-    values = association_proxy("elements", "value")
+    values: AssociationProxy[list[str]] = association_proxy(
+        "elements", "value"
+    )
     """Bridge the association from 'elements' over to the
     'value' element of C."""
 
-    def __init__(self, key, values=None):
+    def __init__(self, key: str, values: list[str] | None = None) -> None:
         self.key = key
         if values:
             self.values = values
@@ -72,10 +73,10 @@ class B(Base):
 
 class C(Base):
     __tablename__ = "c"
-    b_id = Column(Integer, ForeignKey("b.id"), nullable=False)
-    value = Column(Integer)
+    b_id: Mapped[int] = mapped_column(ForeignKey("b.id"))
+    value: Mapped[int]
 
-    def __init__(self, value):
+    def __init__(self, value: int) -> None:
         self.value = value
 
 
@@ -90,7 +91,7 @@ if __name__ == "__main__":
     session.add_all([A(collections={"1": {1, 2, 3}})])
     session.commit()
 
-    a1 = session.query(A).first()
+    a1 = session.query(A).one()
     print(a1.collections["1"])
     a1.collections["1"].add(4)
     session.commit()