]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update generic associations examples for strict typing
authorMike Fiedler <miketheman@gmail.com>
Tue, 20 Jan 2026 18:22:46 +0000 (13:22 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Tue, 20 Jan 2026 18:22:46 +0000 (13:22 -0500)
### Description

Following previous work in #10450 and #12031, add more type hints to the examples.

I added a new test case to exercise these examples in the future as well.

### Checklist

This pull request is:

- [x] A documentation / typographical / small typing error fix
- Good to go, no issue or tests are needed

**Have a nice day!**

Closes: #13082
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13082
Pull-request-sha: 7a2bed3f6d9fe33c621ac2a6ca56364fa5fe7053

Change-Id: Ia143abbc4c8491c2976203d1b99162652b26b417

examples/generic_associations/discriminator_on_association.py
examples/generic_associations/generic_fk.py
examples/generic_associations/table_per_association.py
examples/generic_associations/table_per_related.py
test/typing/test_mypy.py

index 850bcb4f0634afa29e33e5ffc7f8152718fdbe1b..ed32b7a788416bf9e766b6ad82d3880d725c6bc9 100644 (file)
@@ -16,9 +16,15 @@ objects, but is also slightly more complex.
 
 """
 
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
+
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
 from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.associationproxy import AssociationProxy
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import declared_attr
@@ -33,8 +39,8 @@ class Base(DeclarativeBase):
     and surrogate primary key column.
     """
 
-    @declared_attr
-    def __tablename__(cls):
+    @declared_attr.directive
+    def __tablename__(cls) -> str:
         return cls.__name__.lower()
 
     id: Mapped[int] = mapped_column(primary_key=True)
@@ -49,7 +55,7 @@ class AddressAssociation(Base):
 
     discriminator: Mapped[str] = mapped_column()
     """Refers to the type of parent."""
-    addresses: Mapped[list["Address"]] = relationship(
+    addresses: Mapped[list[Address]] = relationship(
         back_populates="association"
     )
 
@@ -69,13 +75,15 @@ class Address(Base):
     street: Mapped[str]
     city: Mapped[str]
     zip: Mapped[str]
-    association: Mapped["AddressAssociation"] = relationship(
+    association: Mapped[AddressAssociation] = relationship(
         back_populates="addresses"
     )
 
-    parent = association_proxy("association", "parent")
+    parent: AssociationProxy[HasAddresses] = association_proxy(
+        "association", "parent"
+    )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(street=%r, city=%r, zip=%r)" % (
             self.__class__.__name__,
             self.street,
@@ -89,12 +97,15 @@ class HasAddresses:
     the address_association table for each parent.
     """
 
+    if TYPE_CHECKING:
+        addresses: AssociationProxy[list[Address]]
+
     @declared_attr
-    def address_association_id(cls) -> Mapped[int]:
+    def address_association_id(cls: type[Any]) -> Mapped[int]:
         return mapped_column(ForeignKey("address_association.id"))
 
     @declared_attr
-    def address_association(cls):
+    def address_association(cls: type[Any]) -> Mapped[AddressAssociation]:
         name = cls.__name__
         discriminator = name.lower()
 
index f82ad635160018c8301e8f3b778afd8c51edb913..fd8d067e307f97f2fbb1e15ec27ded0b7911ef49 100644 (file)
@@ -18,6 +18,12 @@ or "table_per_association" instead of this approach.
 
 """
 
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import TYPE_CHECKING
+
 from sqlalchemy import and_
 from sqlalchemy import create_engine
 from sqlalchemy import event
@@ -27,6 +33,7 @@ from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import foreign
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import Mapper
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import remote
 from sqlalchemy.orm import Session
@@ -37,8 +44,8 @@ class Base(DeclarativeBase):
     and surrogate primary key column.
     """
 
-    @declared_attr
-    def __tablename__(cls):
+    @declared_attr.directive
+    def __tablename__(cls) -> str:
         return cls.__name__.lower()
 
     id: Mapped[int] = mapped_column(primary_key=True)
@@ -65,13 +72,15 @@ class Address(Base):
     """
 
     @property
-    def parent(self):
+    def parent(self) -> HasAddresses:
         """Provides in-Python access to the "parent" by choosing
         the appropriate relationship.
         """
-        return getattr(self, f"parent_{self.discriminator}")
+        return cast(
+            HasAddresses, getattr(self, f"parent_{self.discriminator}")
+        )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(street=%r, city=%r, zip=%r)" % (
             self.__class__.__name__,
             self.street,
@@ -86,9 +95,12 @@ class HasAddresses:
 
     """
 
+    if TYPE_CHECKING:
+        addresses: Mapped[list[Address]]
+
 
 @event.listens_for(HasAddresses, "mapper_configured", propagate=True)
-def setup_listener(mapper, class_):
+def setup_listener(mapper: Mapper[Any], class_: type[Any]) -> None:
     name = class_.__name__
     discriminator = name.lower()
     class_.addresses = relationship(
@@ -106,7 +118,9 @@ def setup_listener(mapper, class_):
     )
 
     @event.listens_for(class_.addresses, "append")
-    def append_address(target, value, initiator):
+    def append_address(
+        target: HasAddresses, value: Address, initiator: Any
+    ) -> None:
         value.discriminator = discriminator
 
 
index 1b75d670c1f27a776cedfd00050caf0295fb1fa1..2d03532d8fd09bc2a31850bcce885cdc900c55a7 100644 (file)
@@ -12,6 +12,8 @@ has no dependency on the system.
 
 """
 
+from __future__ import annotations
+
 from sqlalchemy import Column
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
@@ -29,8 +31,8 @@ class Base(DeclarativeBase):
     and surrogate primary key column.
     """
 
-    @declared_attr
-    def __tablename__(cls):
+    @declared_attr.directive
+    def __tablename__(cls) -> str:
         return cls.__name__.lower()
 
     id: Mapped[int] = mapped_column(primary_key=True)
@@ -47,7 +49,7 @@ class Address(Base):
     city: Mapped[str]
     zip: Mapped[str]
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(street=%r, city=%r, zip=%r)" % (
             self.__class__.__name__,
             self.street,
@@ -63,7 +65,7 @@ class HasAddresses:
     """
 
     @declared_attr
-    def addresses(cls):
+    def addresses(cls: type[DeclarativeBase]) -> Mapped[list[Address]]:
         address_association = Table(
             "%s_addresses" % cls.__tablename__,
             cls.metadata,
index bd4e7d61d1b72f6fcbf948ca24e7f1ddc0da1ece..bd3311d844bd5b3111a882143cf61e2d4b8449b9 100644 (file)
@@ -17,6 +17,11 @@ is completely automated.
 
 """
 
+from __future__ import annotations
+
+from typing import Any
+from typing import TYPE_CHECKING
+
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
@@ -34,8 +39,8 @@ class Base(DeclarativeBase):
 
     """
 
-    @declared_attr
-    def __tablename__(cls):
+    @declared_attr.directive
+    def __tablename__(cls) -> str:
         return cls.__name__.lower()
 
     id: Mapped[int] = mapped_column(primary_key=True)
@@ -55,7 +60,7 @@ class Address:
     city: Mapped[str]
     zip: Mapped[str]
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "%s(street=%r, city=%r, zip=%r)" % (
             self.__class__.__name__,
             self.street,
@@ -64,6 +69,23 @@ class Address:
         )
 
 
+if TYPE_CHECKING:
+
+    class AddressWithParent(Address):
+        """Type stub for Address subclasses created by HasAddresses.
+
+        Inherits street, city, zip from Address.
+
+        Allows mypy to understand when <class>.Address is created,
+        it will have `parent_id` and `parent` attributes.
+        If you won't use `parent_id` attribute directly,
+        there's no need to specify here, included for completeness.
+        """
+
+        parent_id: int
+        parent: HasAddresses
+
+
 class HasAddresses:
     """HasAddresses mixin, creates a new Address class
     for each parent.
@@ -71,7 +93,7 @@ class HasAddresses:
     """
 
     @declared_attr
-    def addresses(cls):
+    def addresses(cls: type[Any]) -> Mapped[list[AddressWithParent]]:
         cls.Address = type(
             f"{cls.__name__}Address",
             (Address, Base),
index 14d13bd6f5d782db2ead12d5bfe864c609390165..4bea968fb1934bbee2a3b88a7837a2657d13ea7b 100644 (file)
@@ -1,4 +1,5 @@
 import os
+from pathlib import Path
 
 from sqlalchemy import testing
 from sqlalchemy.testing import fixtures
@@ -15,3 +16,22 @@ class MypyPlainTest(fixtures.MypyTest):
     )
     def test_mypy_no_plugin(self, mypy_typecheck_file, path):
         mypy_typecheck_file(path)
+
+
+class MypyExamplesTest(fixtures.MypyTest):
+    """Test that examples pass mypy strict mode."""
+
+    # Path to examples/generic_associations relative to repo root
+    _examples_path = Path(__file__).parent.parent.parent / "examples"
+
+    @testing.combinations(
+        *(
+            (path.name, str(path))
+            for path in (_examples_path / "generic_associations").glob("*.py")
+            if path.name != "__init__.py"
+        ),
+        argnames="path",
+        id_="ia",
+    )
+    def test_generic_associations_examples(self, mypy_typecheck_file, path):
+        mypy_typecheck_file(path)