]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update association examples to Declarative API
authorMike Fiedler <miketheman@gmail.com>
Thu, 26 Jun 2025 19:04:03 +0000 (15:04 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 26 Jun 2025 19:17:59 +0000 (21:17 +0200)
### Description

Follows initial attempt in #10450 - but starts with simpler association examples.

### Checklist
This pull request is:

- [x] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed

I was curious how to add these selectively to any of the type hint test suites, to prevent future drift, but didn't see anything too obvious.

Closes: #12031
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12031
Pull-request-sha: dad6239370f23d52b4c0a1b21eba5752e216207e

Change-Id: Id5c2d65137c5e9d7e87778acd51b965c2bcf315a

examples/association/basic_association.py
examples/association/dict_of_sets_with_default.py
examples/association/proxied_association.py

index 7a5b46097e3c4006e2b15de48c7c4e63a5bc3efa..1ef1f698d331b28589720698d159c68a331c5f26 100644 (file)
@@ -10,104 +10,116 @@ of "items", with a particular price paid associated with each "item".
 
 """
 
+from __future__ import annotations
+
 from datetime import datetime
 
-from sqlalchemy import and_
-from sqlalchemy import Column
 from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
 from sqlalchemy import ForeignKey
-from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 
 
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
 
 
 class Order(Base):
     __tablename__ = "order"
 
-    order_id = Column(Integer, primary_key=True)
-    customer_name = Column(String(30), nullable=False)
-    order_date = Column(DateTime, nullable=False, default=datetime.now())
-    order_items = relationship(
-        "OrderItem", cascade="all, delete-orphan", backref="order"
+    order_id: Mapped[int] = mapped_column(primary_key=True)
+    customer_name: Mapped[str] = mapped_column(String(30))
+    order_date: Mapped[datetime] = mapped_column(default=datetime.now())
+    order_items: Mapped[list[OrderItem]] = relationship(
+        cascade="all, delete-orphan", backref="order"
     )
 
-    def __init__(self, customer_name):
+    def __init__(self, customer_name: str) -> None:
         self.customer_name = customer_name
 
 
 class Item(Base):
     __tablename__ = "item"
-    item_id = Column(Integer, primary_key=True)
-    description = Column(String(30), nullable=False)
-    price = Column(Float, nullable=False)
+    item_id: Mapped[int] = mapped_column(primary_key=True)
+    description: Mapped[str] = mapped_column(String(30))
+    price: Mapped[float]
 
-    def __init__(self, description, price):
+    def __init__(self, description: str, price: float) -> None:
         self.description = description
         self.price = price
 
-    def __repr__(self):
-        return "Item(%r, %r)" % (self.description, self.price)
+    def __repr__(self) -> str:
+        return "Item({!r}, {!r})".format(self.description, self.price)
 
 
 class OrderItem(Base):
     __tablename__ = "orderitem"
-    order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True)
-    item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
-    price = Column(Float, nullable=False)
+    order_id: Mapped[int] = mapped_column(
+        ForeignKey("order.order_id"), primary_key=True
+    )
+    item_id: Mapped[int] = mapped_column(
+        ForeignKey("item.item_id"), primary_key=True
+    )
+    price: Mapped[float]
 
-    def __init__(self, item, price=None):
+    def __init__(self, item: Item, price: float | None = None) -> None:
         self.item = item
         self.price = price or item.price
 
-    item = relationship(Item, lazy="joined")
+    item: Mapped[Item] = relationship(lazy="joined")
 
 
 if __name__ == "__main__":
     engine = create_engine("sqlite://")
     Base.metadata.create_all(engine)
 
-    session = Session(engine)
-
-    # create catalog
-    tshirt, mug, hat, crowbar = (
-        Item("SA T-Shirt", 10.99),
-        Item("SA Mug", 6.50),
-        Item("SA Hat", 8.99),
-        Item("MySQL Crowbar", 16.99),
-    )
-    session.add_all([tshirt, mug, hat, crowbar])
-    session.commit()
-
-    # create an order
-    order = Order("john smith")
-
-    # add three OrderItem associations to the Order and save
-    order.order_items.append(OrderItem(mug))
-    order.order_items.append(OrderItem(crowbar, 10.99))
-    order.order_items.append(OrderItem(hat))
-    session.add(order)
-    session.commit()
-
-    # query the order, print items
-    order = session.query(Order).filter_by(customer_name="john smith").one()
-    print(
-        [
-            (order_item.item.description, order_item.price)
-            for order_item in order.order_items
-        ]
-    )
-
-    # print customers who bought 'MySQL Crowbar' on sale
-    q = session.query(Order).join(OrderItem).join(Item)
-    q = q.filter(
-        and_(Item.description == "MySQL Crowbar", Item.price > OrderItem.price)
-    )
-
-    print([order.customer_name for order in q])
+    with Session(engine) as session:
+
+        # create catalog
+        tshirt, mug, hat, crowbar = (
+            Item("SA T-Shirt", 10.99),
+            Item("SA Mug", 6.50),
+            Item("SA Hat", 8.99),
+            Item("MySQL Crowbar", 16.99),
+        )
+        session.add_all([tshirt, mug, hat, crowbar])
+        session.commit()
+
+        # create an order
+        order = Order("john smith")
+
+        # add three OrderItem associations to the Order and save
+        order.order_items.append(OrderItem(mug))
+        order.order_items.append(OrderItem(crowbar, 10.99))
+        order.order_items.append(OrderItem(hat))
+        session.add(order)
+        session.commit()
+
+        # query the order, print items
+        order = session.scalars(
+            select(Order).filter_by(customer_name="john smith")
+        ).one()
+        print(
+            [
+                (order_item.item.description, order_item.price)
+                for order_item in order.order_items
+            ]
+        )
+
+        # print customers who bought 'MySQL Crowbar' on sale
+        q = (
+            select(Order)
+            .join(OrderItem)
+            .join(Item)
+            .where(
+                Item.description == "MySQL Crowbar",
+                Item.price > OrderItem.price,
+            )
+        )
+
+        print([order.customer_name for order in session.scalars(q)])
index f515ab975b5b8839a71372d6b323ca6df7dc26f2..fef3c1d57a2d075e51570321486098a8f4caf9a9 100644 (file)
@@ -12,43 +12,46 @@ upon access of a non-existent key, in the same manner as Python's
 
 """
 
+from __future__ import annotations
+
 import operator
+from typing import Mapping
 
-from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
-from sqlalchemy import Integer
-from sqlalchemy import String
+from sqlalchemy import select
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.collections import KeyFuncDict
 
 
-class Base:
-    id = Column(Integer, primary_key=True)
-
+class Base(DeclarativeBase):
+    id: Mapped[int] = mapped_column(primary_key=True)
 
-Base = declarative_base(cls=Base)
 
-
-class GenDefaultCollection(KeyFuncDict):
-    def __missing__(self, key):
+class GenDefaultCollection(KeyFuncDict[str, "B"]):
+    def __missing__(self, key: str) -> B:
         self[key] = b = B(key)
         return b
 
 
 class A(Base):
     __tablename__ = "a"
-    associations = relationship(
+    associations: Mapped[Mapping[str, B]] = relationship(
         "B",
         collection_class=lambda: GenDefaultCollection(
             operator.attrgetter("key")
         ),
     )
 
-    collections = association_proxy("associations", "values")
+    collections: AssociationProxy[dict[str, set[int]]] = association_proxy(
+        "associations", "values"
+    )
     """Bridge the association from 'associations' over to the 'values'
     association proxy of B.
     """
@@ -56,15 +59,15 @@ class A(Base):
 
 class B(Base):
     __tablename__ = "b"
-    a_id = Column(Integer, ForeignKey("a.id"), nullable=False)
-    elements = relationship("C", collection_class=set)
-    key = Column(String)
+    a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+    elements: Mapped[set[C]] = relationship("C", collection_class=set)
+    key: Mapped[str]
 
-    values = association_proxy("elements", "value")
+    values: AssociationProxy[set[int]] = association_proxy("elements", "value")
     """Bridge the association from 'elements' over to the
     'value' element of C."""
 
-    def __init__(self, key, values=None):
+    def __init__(self, key: str, values: set[int] | None = None) -> None:
         self.key = key
         if values:
             self.values = values
@@ -72,10 +75,10 @@ class B(Base):
 
 class C(Base):
     __tablename__ = "c"
-    b_id = Column(Integer, ForeignKey("b.id"), nullable=False)
-    value = Column(Integer)
+    b_id: Mapped[int] = mapped_column(ForeignKey("b.id"))
+    value: Mapped[int]
 
-    def __init__(self, value):
+    def __init__(self, value: int) -> None:
         self.value = value
 
 
@@ -90,7 +93,7 @@ if __name__ == "__main__":
     session.add_all([A(collections={"1": {1, 2, 3}})])
     session.commit()
 
-    a1 = session.query(A).first()
+    a1 = session.scalars(select(A)).one()
     print(a1.collections["1"])
     a1.collections["1"].add(4)
     session.commit()
index 65dcd6c0b66e9cddb5e9442eaea20e6903fc558a..0f18e167eba35030e8d0bbd7aeefd002f2cbec58 100644 (file)
@@ -5,116 +5,127 @@ to ``OrderItem`` optional.
 
 """
 
+from __future__ import annotations
+
 from datetime import datetime
 
-from sqlalchemy import Column
 from sqlalchemy import create_engine
-from sqlalchemy import DateTime
-from sqlalchemy import Float
 from sqlalchemy import ForeignKey
-from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.associationproxy import AssociationProxy
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 
 
-Base = declarative_base()
+class Base(DeclarativeBase):
+    pass
 
 
 class Order(Base):
     __tablename__ = "order"
 
-    order_id = Column(Integer, primary_key=True)
-    customer_name = Column(String(30), nullable=False)
-    order_date = Column(DateTime, nullable=False, default=datetime.now())
-    order_items = relationship(
-        "OrderItem", cascade="all, delete-orphan", backref="order"
+    order_id: Mapped[int] = mapped_column(primary_key=True)
+    customer_name: Mapped[str] = mapped_column(String(30))
+    order_date: Mapped[datetime] = mapped_column(default=datetime.now())
+    order_items: Mapped[list[OrderItem]] = relationship(
+        cascade="all, delete-orphan", backref="order"
+    )
+    items: AssociationProxy[list[Item]] = association_proxy(
+        "order_items", "item"
     )
-    items = association_proxy("order_items", "item")
 
-    def __init__(self, customer_name):
+    def __init__(self, customer_name: str) -> None:
         self.customer_name = customer_name
 
 
 class Item(Base):
     __tablename__ = "item"
-    item_id = Column(Integer, primary_key=True)
-    description = Column(String(30), nullable=False)
-    price = Column(Float, nullable=False)
+    item_id: Mapped[int] = mapped_column(primary_key=True)
+    description: Mapped[str] = mapped_column(String(30))
+    price: Mapped[float]
 
-    def __init__(self, description, price):
+    def __init__(self, description: str, price: float) -> None:
         self.description = description
         self.price = price
 
-    def __repr__(self):
-        return "Item(%r, %r)" % (self.description, self.price)
+    def __repr__(self) -> str:
+        return "Item({!r}, {!r})".format(self.description, self.price)
 
 
 class OrderItem(Base):
     __tablename__ = "orderitem"
-    order_id = Column(Integer, ForeignKey("order.order_id"), primary_key=True)
-    item_id = Column(Integer, ForeignKey("item.item_id"), primary_key=True)
-    price = Column(Float, nullable=False)
+    order_id: Mapped[int] = mapped_column(
+        ForeignKey("order.order_id"), primary_key=True
+    )
+    item_id: Mapped[int] = mapped_column(
+        ForeignKey("item.item_id"), primary_key=True
+    )
+    price: Mapped[float]
+
+    item: Mapped[Item] = relationship(lazy="joined")
 
-    def __init__(self, item, price=None):
+    def __init__(self, item: Item, price: float | None = None):
         self.item = item
         self.price = price or item.price
 
-    item = relationship(Item, lazy="joined")
-
 
 if __name__ == "__main__":
     engine = create_engine("sqlite://")
     Base.metadata.create_all(engine)
 
-    session = Session(engine)
-
-    # create catalog
-    tshirt, mug, hat, crowbar = (
-        Item("SA T-Shirt", 10.99),
-        Item("SA Mug", 6.50),
-        Item("SA Hat", 8.99),
-        Item("MySQL Crowbar", 16.99),
-    )
-    session.add_all([tshirt, mug, hat, crowbar])
-    session.commit()
-
-    # create an order
-    order = Order("john smith")
-
-    # add items via the association proxy.
-    # the OrderItem is created automatically.
-    order.items.append(mug)
-    order.items.append(hat)
-
-    # add an OrderItem explicitly.
-    order.order_items.append(OrderItem(crowbar, 10.99))
-
-    session.add(order)
-    session.commit()
-
-    # query the order, print items
-    order = session.query(Order).filter_by(customer_name="john smith").one()
-
-    # print items based on the OrderItem collection directly
-    print(
-        [
-            (assoc.item.description, assoc.price, assoc.item.price)
-            for assoc in order.order_items
-        ]
-    )
-
-    # print items based on the "proxied" items collection
-    print([(item.description, item.price) for item in order.items])
-
-    # print customers who bought 'MySQL Crowbar' on sale
-    orders = (
-        session.query(Order)
-        .join(OrderItem)
-        .join(Item)
-        .filter(Item.description == "MySQL Crowbar")
-        .filter(Item.price > OrderItem.price)
-    )
-    print([o.customer_name for o in orders])
+    with Session(engine) as session:
+
+        # create catalog
+        tshirt, mug, hat, crowbar = (
+            Item("SA T-Shirt", 10.99),
+            Item("SA Mug", 6.50),
+            Item("SA Hat", 8.99),
+            Item("MySQL Crowbar", 16.99),
+        )
+        session.add_all([tshirt, mug, hat, crowbar])
+        session.commit()
+
+        # create an order
+        order = Order("john smith")
+
+        # add items via the association proxy.
+        # the OrderItem is created automatically.
+        order.items.append(mug)
+        order.items.append(hat)
+
+        # add an OrderItem explicitly.
+        order.order_items.append(OrderItem(crowbar, 10.99))
+
+        session.add(order)
+        session.commit()
+
+        # query the order, print items
+        order = session.scalars(
+            select(Order).filter_by(customer_name="john smith")
+        ).one()
+
+        # print items based on the OrderItem collection directly
+        print(
+            [
+                (assoc.item.description, assoc.price, assoc.item.price)
+                for assoc in order.order_items
+            ]
+        )
+
+        # print items based on the "proxied" items collection
+        print([(item.description, item.price) for item in order.items])
+
+        # print customers who bought 'MySQL Crowbar' on sale
+        orders_stmt = (
+            select(Order)
+            .join(OrderItem)
+            .join(Item)
+            .filter(Item.description == "MySQL Crowbar")
+            .filter(Item.price > OrderItem.price)
+        )
+        print([o.customer_name for o in session.scalars(orders_stmt)])